{"version":3,"file":"tf.js","sources":["../node_modules/@tensorflow/tfjs-core/src/environment.ts","../node_modules/@tensorflow/tfjs-core/src/kernel_registry.ts","../node_modules/@tensorflow/tfjs-core/src/util.ts","../node_modules/@tensorflow/tfjs-core/src/profiler.ts","../node_modules/@tensorflow/tfjs-core/src/tensor_format.ts","../node_modules/@tensorflow/tfjs-core/src/tensor.ts","../node_modules/@tensorflow/tfjs-core/src/types.ts","../node_modules/@tensorflow/tfjs-core/src/tensor_util.ts","../node_modules/@tensorflow/tfjs-core/src/engine.ts","../node_modules/@tensorflow/tfjs-core/src/tape.ts","../node_modules/@tensorflow/tfjs-core/src/device_util.ts","../node_modules/@tensorflow/tfjs-core/src/flags.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/canvas_util.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/tex_util.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/webgl_util.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/flags_webgl.ts","../node_modules/@tensorflow/tfjs-core/src/globals.ts","../node_modules/@tensorflow/tfjs-core/src/log.ts","../node_modules/@tensorflow/tfjs-core/src/tensor_util_env.ts","../node_modules/@tensorflow/tfjs-core/src/ops/axis_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/concat_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/operation.ts","../node_modules/@tensorflow/tfjs-core/src/ops/complex_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/tensor_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/concat_split.ts","../node_modules/@tensorflow/tfjs-core/src/ops/array_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/array_ops_util.ts","../node_modules/@tensorflow/tfjs-core/src/kernel_names.ts","../node_modules/@tensorflow/tfjs-core/src/ops/add.ts","../node_modules/@tensorflow/tfjs-core/src/ops/broadcast_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/unary_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/binary_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/div.ts","../node_modules/@tensorflow/tfjs-core/src/ops/gather_nd_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/reduce_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/scatter_nd_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/slice_util.ts","../node_modules/@tensorflow/tfjs-core/src/gradients.ts","../node_modules/@tensorflow/tfjs-core/src/ops/softmax.ts","../node_modules/@tensorflow/tfjs-core/src/ops/transpose.ts","../node_modules/@tensorflow/tfjs-core/src/backends/backend.ts","../node_modules/@tensorflow/tfjs-core/src/ops/conv_util.ts","../node_modules/@tensorflow/tfjs-core/src/backends/backend_util.ts","../node_modules/@tensorflow/tfjs-core/src/backends/complex_util.ts","../node_modules/@tensorflow/tfjs-core/src/backends/array_util.ts","../node_modules/@tensorflow/tfjs-core/src/backends/non_max_suppression_impl.ts","../node_modules/@tensorflow/tfjs-core/src/backends/split_shared.ts","../node_modules/@tensorflow/tfjs-core/src/backends/tile_impl.ts","../node_modules/@tensorflow/tfjs-core/src/backends/topk_impl.ts","../node_modules/@tensorflow/tfjs-core/src/backends/where_impl.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/addn_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/addn_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/argminmax_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/packing_util.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/glsl_version.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/shader_compiler_util.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/shader_compiler.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/argminmax_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/avg_pool_backprop_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/batchnorm_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/batchnorm_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/binaryop_complex_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/binaryop_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/binaryop_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/clip_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/clip_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/complex_abs_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/concat_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/concat_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/conv_backprop_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/conv_backprop_gpu_depthwise.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/conv_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/conv_gpu_depthwise.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/conv_packed_gpu_depthwise.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/crop_and_resize_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/cumsum_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/decode_matrix_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/decode_matrix_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/depth_to_space_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/diag_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/encode_float_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/encode_float_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/encode_matrix_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/encode_matrix_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/fft_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/fill_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/gather_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/gather_nd_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/gpgpu_util.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/gpgpu_context.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/gpgpu_math.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/im2col_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/lrn_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/lrn_grad_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/lrn_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/max_pool_backprop_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/mulmat_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/multinomial_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/onehot_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/pack_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/pad_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/pad_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/pool_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/reduce_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/reshape_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/resize_bilinear_backprop_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/resize_bilinear_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/resize_bilinear_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/resize_nearest_neighbor_backprop_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/resize_nearest_neighbor_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/reverse_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/reverse_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/scatter_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/segment_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/select_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/slice_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/slice_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/strided_slice_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/texture_manager.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/tile_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/ops/erf_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/selu_util.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/unaryop_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/unaryop_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/unpack_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/backend_webgl.ts","../node_modules/@tensorflow/tfjs-core/src/ops/segment_util.ts","../node_modules/@tensorflow/tfjs-core/node_modules/seedrandom/lib/alea.js","../node_modules/@tensorflow/tfjs-core/node_modules/seedrandom/lib/xor128.js","../node_modules/@tensorflow/tfjs-core/node_modules/seedrandom/lib/xorwow.js","../node_modules/@tensorflow/tfjs-core/node_modules/seedrandom/lib/xorshift7.js","../node_modules/@tensorflow/tfjs-core/node_modules/seedrandom/lib/xor4096.js","../node_modules/@tensorflow/tfjs-core/node_modules/seedrandom/lib/tychei.js","../node_modules/@tensorflow/tfjs-core/node_modules/seedrandom/seedrandom.js","../node_modules/@tensorflow/tfjs-core/node_modules/seedrandom/index.js","../node_modules/@tensorflow/tfjs-core/src/ops/add_n.ts","../node_modules/@tensorflow/tfjs-core/src/ops/batchnorm_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/batchnorm.ts","../node_modules/@tensorflow/tfjs-core/src/ops/batchnorm2d.ts","../node_modules/@tensorflow/tfjs-core/src/ops/batchnorm3d.ts","../node_modules/@tensorflow/tfjs-core/src/ops/batchnorm4d.ts","../node_modules/@tensorflow/tfjs-core/src/ops/broadcast_to.ts","../node_modules/@tensorflow/tfjs-core/src/ops/clone.ts","../node_modules/@tensorflow/tfjs-core/src/ops/logical_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/div_no_nan.ts","../node_modules/@tensorflow/tfjs-core/src/ops/tile.ts","../node_modules/@tensorflow/tfjs-core/src/ops/eye.ts","../node_modules/@tensorflow/tfjs-core/src/ops/multinomial.ts","../node_modules/@tensorflow/tfjs-core/src/ops/one_hot.ts","../node_modules/@tensorflow/tfjs-core/src/ops/pad.ts","../node_modules/@tensorflow/tfjs-core/src/ops/pad1d.ts","../node_modules/@tensorflow/tfjs-core/src/ops/pad2d.ts","../node_modules/@tensorflow/tfjs-core/src/ops/pad3d.ts","../node_modules/@tensorflow/tfjs-core/src/ops/pad4d.ts","../node_modules/@tensorflow/tfjs-core/src/ops/rand.ts","../node_modules/@tensorflow/tfjs-core/src/test_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/rand_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/random_gamma.ts","../node_modules/@tensorflow/tfjs-core/src/ops/random_normal.ts","../node_modules/@tensorflow/tfjs-core/src/ops/random_uniform.ts","../node_modules/@tensorflow/tfjs-core/src/ops/square.ts","../node_modules/@tensorflow/tfjs-core/src/ops/squared_difference.ts","../node_modules/@tensorflow/tfjs-core/src/ops/truncated_normal.ts","../node_modules/@tensorflow/tfjs-core/src/ops/compare.ts","../node_modules/@tensorflow/tfjs-core/src/ops/segment_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/boolean_mask.ts","../node_modules/@tensorflow/tfjs-core/src/ops/conv.ts","../node_modules/@tensorflow/tfjs-core/src/ops/matmul.ts","../node_modules/@tensorflow/tfjs-core/src/ops/reverse.ts","../node_modules/@tensorflow/tfjs-core/src/ops/pool.ts","../node_modules/@tensorflow/tfjs-core/src/ops/slice.ts","../node_modules/@tensorflow/tfjs-core/src/ops/reduction_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/relu_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/lrn.ts","../node_modules/@tensorflow/tfjs-core/src/ops/norm.ts","../node_modules/@tensorflow/tfjs-core/src/ops/lstm.ts","../node_modules/@tensorflow/tfjs-core/src/ops/moving_average.ts","../node_modules/@tensorflow/tfjs-core/src/ops/strided_slice.ts","../node_modules/@tensorflow/tfjs-core/src/ops/topk.ts","../node_modules/@tensorflow/tfjs-core/src/ops/scatter_nd.ts","../node_modules/@tensorflow/tfjs-core/src/ops/spectral_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/sparse_to_dense.ts","../node_modules/@tensorflow/tfjs-core/src/ops/sparse_to_dense_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/gather_nd.ts","../node_modules/@tensorflow/tfjs-core/src/ops/diag.ts","../node_modules/@tensorflow/tfjs-core/src/ops/dropout.ts","../node_modules/@tensorflow/tfjs-core/src/ops/dropout_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/signal_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/loss_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/in_top_k.ts","../node_modules/@tensorflow/tfjs-core/src/ops/linalg_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/image_ops.ts","../node_modules/@tensorflow/tfjs-core/src/ops/fused_util.ts","../node_modules/@tensorflow/tfjs-core/src/ops/fused_ops.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/cpu_util.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/pool_utils.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/backend_cpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/utils/kernel_utils.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/kernels/Div_impl.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/kernels/Div.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/kernels/MaxPoolWithArgmax.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/kernels/MaxPoolWithArgmax_impl.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/kernels/NonMaxSuppressionV5.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/kernels/Square.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/kernels/SquaredDifference.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/kernels/Transpose_impl.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/register_all_kernels.ts","../node_modules/@tensorflow/tfjs-core/src/backends/cpu/kernels/Transpose.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/kernels/Div.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/kernels/FromPixels_utils/from_pixels_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/kernels/FromPixels_utils/from_pixels_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/kernels/MaxPoolWithArgmax.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/transpose_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/transpose_packed_gpu.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/register_all_kernels.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/kernels/Transpose.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/kernels/FromPixels.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/kernels/Div_impl.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/kernels/NonMaxSuppressionV5.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/kernels/Square.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/kernels/SquaredDifference.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/kernels/Transpose_impl.ts","../node_modules/@tensorflow/tfjs-core/src/backends/webgl/kernels/MaxPoolWithArgmax_impl.ts","../node_modules/@tensorflow/tfjs-core/src/register_all_gradients.ts","../node_modules/@tensorflow/tfjs-core/src/gradients/Add_grad.ts","../node_modules/@tensorflow/tfjs-core/src/gradients/AddN_grad.ts","../node_modules/@tensorflow/tfjs-core/src/gradients/BroadcastTo_grad.ts","../node_modules/@tensorflow/tfjs-core/src/gradients/Div_grad.ts","../node_modules/@tensorflow/tfjs-core/src/gradients/FusedBatchNorm_grad.ts","../node_modules/@tensorflow/tfjs-core/src/gradients/Identity_grad.ts","../node_modules/@tensorflow/tfjs-core/src/gradients/OneHot_grad.ts","../node_modules/@tensorflow/tfjs-core/src/gradients/PadV2_grad.ts","../node_modules/@tensorflow/tfjs-core/src/gradients/Square_grad.ts","../node_modules/@tensorflow/tfjs-core/src/gradients/SquaredDifference_grad.ts","../node_modules/@tensorflow/tfjs-core/src/gradients/Tile_grad.ts","../node_modules/@tensorflow/tfjs-core/src/gradients/Transpose_grad.ts","../node_modules/@tensorflow/tfjs-core/src/platforms/platform_browser.ts","../node_modules/@tensorflow/tfjs-core/src/platforms/platform_node.ts","../node_modules/@tensorflow/tfjs-core/src/io/types.ts","../node_modules/@tensorflow/tfjs-core/src/io/io_utils.ts","../node_modules/@tensorflow/tfjs-core/src/io/router_registry.ts","../node_modules/@tensorflow/tfjs-core/src/io/model_management.ts","../node_modules/@tensorflow/tfjs-core/src/io/indexed_db.ts","../node_modules/@tensorflow/tfjs-core/src/io/local_storage.ts","../node_modules/@tensorflow/tfjs-core/src/io/browser_files.ts","../node_modules/@tensorflow/tfjs-core/src/io/progress.ts","../node_modules/@tensorflow/tfjs-core/src/io/weights_loader.ts","../node_modules/@tensorflow/tfjs-core/src/io/http.ts","../node_modules/@tensorflow/tfjs-core/src/io/passthrough.ts","../node_modules/@tensorflow/tfjs-core/src/ops/browser.ts","../node_modules/@tensorflow/tfjs-core/src/ops/confusion_matrix.ts","../node_modules/@tensorflow/tfjs-core/src/serialization.ts","../node_modules/@tensorflow/tfjs-core/src/version.ts","../node_modules/@tensorflow/tfjs-core/src/webgl.ts","../node_modules/@tensorflow/tfjs-core/src/optimizers/optimizer.ts","../node_modules/@tensorflow/tfjs-core/src/optimizers/adadelta_optimizer.ts","../node_modules/@tensorflow/tfjs-core/src/optimizers/adagrad_optimizer.ts","../node_modules/@tensorflow/tfjs-core/src/optimizers/adam_optimizer.ts","../node_modules/@tensorflow/tfjs-core/src/optimizers/adamax_optimizer.ts","../node_modules/@tensorflow/tfjs-core/src/optimizers/sgd_optimizer.ts","../node_modules/@tensorflow/tfjs-core/src/optimizers/momentum_optimizer.ts","../node_modules/@tensorflow/tfjs-core/src/optimizers/rmsprop_optimizer.ts","../node_modules/@tensorflow/tfjs-core/src/optimizers/optimizer_constructors.ts","../node_modules/@tensorflow/tfjs-core/src/train.ts","../node_modules/@tensorflow/tfjs-core/src/browser_util.ts","../node_modules/@tensorflow/tfjs-core/src/public/chained_ops/add.ts","../node_modules/@tensorflow/tfjs-core/src/public/chained_ops/broadcast_to.ts","../node_modules/@tensorflow/tfjs-core/src/public/chained_ops/div.ts","../node_modules/@tensorflow/tfjs-core/src/public/chained_ops/div_no_nan.ts","../node_modules/@tensorflow/tfjs-core/src/public/chained_ops/squared_difference.ts","../node_modules/@tensorflow/tfjs-core/src/public/chained_ops/tile.ts","../node_modules/@tensorflow/tfjs-core/src/public/chained_ops/one_hot.ts","../node_modules/@tensorflow/tfjs-core/src/public/chained_ops/transpose.ts","../node_modules/@tensorflow/tfjs-core/src/public/chained_ops/pad.ts","../node_modules/@tensorflow/tfjs-core/src/public/chained_ops/batchnorm.ts","../node_modules/@tensorflow/tfjs-core/src/index.ts","../node_modules/@tensorflow/tfjs-layers/src/backend/common.ts","../node_modules/@tensorflow/tfjs-layers/src/errors.ts","../node_modules/@tensorflow/tfjs-layers/src/utils/generic_utils.ts","../node_modules/@tensorflow/tfjs-layers/src/constraints.ts","../node_modules/@tensorflow/tfjs-layers/src/exports_constraints.ts","../node_modules/@tensorflow/tfjs-layers/src/keras_format/common.ts","../node_modules/@tensorflow/tfjs-layers/src/common.ts","../node_modules/@tensorflow/tfjs-layers/src/utils/math_utils.ts","../node_modules/@tensorflow/tfjs-layers/src/backend/tfjs_backend.ts","../node_modules/@tensorflow/tfjs-layers/src/keras_format/initializer_config.ts","../node_modules/@tensorflow/tfjs-layers/src/initializers.ts","../node_modules/@tensorflow/tfjs-layers/src/exports_initializers.ts","../node_modules/@tensorflow/tfjs-layers/src/backend/state.ts","../node_modules/@tensorflow/tfjs-layers/src/utils/types_utils.ts","../node_modules/@tensorflow/tfjs-layers/src/utils/variable_utils.ts","../node_modules/@tensorflow/tfjs-layers/src/variables.ts","../node_modules/@tensorflow/tfjs-layers/src/engine/topology.ts","../node_modules/@tensorflow/tfjs-layers/src/base_callbacks.ts","../node_modules/@tensorflow/tfjs-layers/src/engine/input_layer.ts","../node_modules/@tensorflow/tfjs-layers/src/logs.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/serialization.ts","../node_modules/@tensorflow/tfjs-layers/src/losses.ts","../node_modules/@tensorflow/tfjs-layers/src/metrics.ts","../node_modules/@tensorflow/tfjs-layers/src/optimizers.ts","../node_modules/@tensorflow/tfjs-layers/src/user_defined_metadata.ts","../node_modules/@tensorflow/tfjs-layers/src/utils/layer_utils.ts","../node_modules/@tensorflow/tfjs-layers/src/utils/serialization_utils.ts","../node_modules/@tensorflow/tfjs-layers/src/version.ts","../node_modules/@tensorflow/tfjs-layers/src/engine/executor.ts","../node_modules/@tensorflow/tfjs-layers/src/engine/container.ts","../node_modules/@tensorflow/tfjs-layers/src/engine/training_utils.ts","../node_modules/@tensorflow/tfjs-layers/src/engine/training_dataset.ts","../node_modules/@tensorflow/tfjs-layers/src/engine/training_tensors.ts","../node_modules/@tensorflow/tfjs-layers/src/engine/training.ts","../node_modules/@tensorflow/tfjs-layers/src/models.ts","../node_modules/@tensorflow/tfjs-layers/src/exports.ts","../node_modules/@tensorflow/tfjs-layers/src/activations.ts","../node_modules/@tensorflow/tfjs-layers/src/regularizers.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/advanced_activations.ts","../node_modules/@tensorflow/tfjs-layers/src/utils/conv_utils.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/convolutional.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/convolutional_depthwise.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/core.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/embeddings.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/merge.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/noise.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/normalization.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/padding.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/pooling.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/recurrent.ts","../node_modules/@tensorflow/tfjs-layers/src/layers/wrappers.ts","../node_modules/@tensorflow/tfjs-layers/src/exports_layers.ts","../node_modules/@tensorflow/tfjs-layers/src/exports_metrics.ts","../node_modules/@tensorflow/tfjs-layers/src/exports_regularizers.ts","../node_modules/@tensorflow/tfjs-layers/src/callbacks.ts","../node_modules/@tensorflow/tfjs-converter/src/data/compiled_api.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/custom_op/register.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/utils.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/arithmetic.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/basic_math.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/control.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/convolution.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/creation.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/dynamic.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/evaluation.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/graph.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/image.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/logical.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/matrices.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/normalization.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/reduction.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/slice_join.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/spectral.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/op_list/transformation.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/operation_mapper.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/custom_op/node_value_impl.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/arithmetic_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/basic_math_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/executor/tensor_array.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/control_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/convolution_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/creation_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/dynamic_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/evaluation_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/graph_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/image_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/logical_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/matrices_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/normalization_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/reduction_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/slice_join_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/spectral_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/executors/transformation_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/operations/operation_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/executor/execution_context.ts","../node_modules/@tensorflow/tfjs-converter/src/executor/model_analysis.ts","../node_modules/@tensorflow/tfjs-converter/src/executor/graph_executor.ts","../node_modules/@tensorflow/tfjs-converter/src/executor/graph_model.ts","../node_modules/@tensorflow/tfjs-converter/src/version.ts","../node_modules/@tensorflow/tfjs-data/node_modules/seedrandom/lib/alea.js","../node_modules/@tensorflow/tfjs-data/node_modules/seedrandom/lib/xor128.js","../node_modules/@tensorflow/tfjs-data/node_modules/seedrandom/lib/xorwow.js","../node_modules/@tensorflow/tfjs-data/node_modules/seedrandom/lib/xorshift7.js","../node_modules/@tensorflow/tfjs-data/node_modules/seedrandom/lib/xor4096.js","../node_modules/@tensorflow/tfjs-data/node_modules/seedrandom/lib/tychei.js","../node_modules/@tensorflow/tfjs-data/node_modules/seedrandom/seedrandom.js","../node_modules/@tensorflow/tfjs-data/node_modules/seedrandom/index.js","../node_modules/@tensorflow/tfjs-data/src/util/deep_map.ts","../node_modules/@tensorflow/tfjs-data/src/util/deep_clone.ts","../node_modules/@tensorflow/tfjs-data/src/util/ring_buffer.ts","../node_modules/@tensorflow/tfjs-data/src/util/growing_ring_buffer.ts","../node_modules/@tensorflow/tfjs-data/src/iterators/lazy_iterator.ts","../node_modules/@tensorflow/tfjs-data/src/dataset.ts","../node_modules/@tensorflow/tfjs-data/src/datasets/text_line_dataset.ts","../node_modules/@tensorflow/tfjs-data/src/datasets/csv_dataset.ts","../node_modules/@tensorflow/tfjs-data/src/iterators/microphone_iterator.ts","../node_modules/@tensorflow/tfjs-data/src/iterators/webcam_iterator.ts","../node_modules/@tensorflow/tfjs-data/src/datasource.ts","../node_modules/@tensorflow/tfjs-data/src/iterators/string_iterator.ts","../node_modules/@tensorflow/tfjs-data/src/iterators/byte_chunk_iterator.ts","../node_modules/@tensorflow/tfjs-data/src/iterators/file_chunk_iterator.ts","../node_modules/@tensorflow/tfjs-data/src/iterators/url_chunk_iterator.ts","../node_modules/@tensorflow/tfjs-data/src/util/source_util.ts","../node_modules/@tensorflow/tfjs-data/src/sources/file_data_source.ts","../node_modules/@tensorflow/tfjs-data/src/sources/url_data_source.ts","../node_modules/@tensorflow/tfjs-data/src/readers.ts","../node_modules/@tensorflow/tfjs-data/src/version.ts","../src/version.ts","../src/index.ts"],"sourcesContent":["/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Platform} from './platforms/platform';\n\n// Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.\nconst TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';\n\ntype FlagValue = number|boolean;\nexport type Flags = {\n [featureName: string]: FlagValue\n};\nexport type FlagRegistryEntry = {\n evaluationFn: () => FlagValue;\n setHook?: (value: FlagValue) => void;\n};\n\n/**\n * The environment contains evaluated flags as well as the registered platform.\n * This is always used as a global singleton and can be retrieved with\n * `tf.env()`.\n */\n/** @doc {heading: 'Environment'} */\nexport class Environment {\n private flags: Flags = {};\n private flagRegistry: {[flagName: string]: FlagRegistryEntry} = {};\n\n private urlFlags: Flags = {};\n\n platformName: string;\n platform: Platform;\n\n // tslint:disable-next-line: no-any\n constructor(public global: any) {\n this.populateURLFlags();\n }\n\n setPlatform(platformName: string, platform: Platform) {\n if (this.platform != null) {\n console.warn(\n `Platform ${this.platformName} has already been set. ` +\n `Overwriting the platform with ${platform}.`);\n }\n this.platformName = platformName;\n this.platform = platform;\n }\n\n registerFlag(\n flagName: string, evaluationFn: () => FlagValue,\n setHook?: (value: FlagValue) => void) {\n this.flagRegistry[flagName] = {evaluationFn, setHook};\n\n // Override the flag value from the URL. This has to happen here because the\n // environment is initialized before flags get registered.\n if (this.urlFlags[flagName] != null) {\n const flagValue = this.urlFlags[flagName];\n console.warn(\n `Setting feature override from URL ${flagName}: ${flagValue}.`);\n this.set(flagName, flagValue);\n }\n }\n\n get(flagName: string): FlagValue {\n if (flagName in this.flags) {\n return this.flags[flagName];\n }\n\n this.flags[flagName] = this.evaluateFlag(flagName);\n\n return this.flags[flagName];\n }\n\n getNumber(flagName: string): number {\n return this.get(flagName) as number;\n }\n\n getBool(flagName: string): boolean {\n return this.get(flagName) as boolean;\n }\n\n getFlags(): Flags {\n return this.flags;\n }\n // For backwards compatibility.\n get features(): Flags {\n return this.flags;\n }\n\n set(flagName: string, value: FlagValue): void {\n if (this.flagRegistry[flagName] == null) {\n throw new Error(\n `Cannot set flag ${flagName} as it has not been registered.`);\n }\n this.flags[flagName] = value;\n if (this.flagRegistry[flagName].setHook != null) {\n this.flagRegistry[flagName].setHook(value);\n }\n }\n\n private evaluateFlag(flagName: string): FlagValue {\n if (this.flagRegistry[flagName] == null) {\n throw new Error(\n `Cannot evaluate flag '${flagName}': no evaluation function found.`);\n }\n return this.flagRegistry[flagName].evaluationFn();\n }\n\n setFlags(flags: Flags) {\n this.flags = Object.assign({}, flags);\n }\n\n reset() {\n this.flags = {};\n this.urlFlags = {};\n this.populateURLFlags();\n }\n\n private populateURLFlags(): void {\n if (typeof this.global === 'undefined' ||\n typeof this.global.location === 'undefined' ||\n typeof this.global.location.search === 'undefined') {\n return;\n }\n\n const urlParams = getQueryParams(this.global.location.search);\n if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {\n const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');\n keyValues.forEach(keyValue => {\n const [key, value] = keyValue.split(':') as [string, string];\n this.urlFlags[key] = parseValue(key, value);\n });\n }\n }\n}\n\nexport function getQueryParams(queryString: string): {[key: string]: string} {\n const params = {};\n queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (s, ...t) => {\n decodeParam(params, t[0], t[1]);\n return t.join('=');\n });\n return params;\n}\n\nfunction decodeParam(\n params: {[key: string]: string}, name: string, value?: string) {\n params[decodeURIComponent(name)] = decodeURIComponent(value || '');\n}\n\nfunction parseValue(flagName: string, value: string): FlagValue {\n value = value.toLowerCase();\n if (value === 'true' || value === 'false') {\n return value === 'true';\n } else if (`${+ value}` === value) {\n return +value;\n }\n throw new Error(\n `Could not parse value flag value ${value} for flag ${flagName}.`);\n}\n\n/**\n * Returns the current environment (a global singleton).\n *\n * The environment object contains the evaluated feature values as well as the\n * active platform.\n */\n/** @doc {heading: 'Environment'} */\nexport function env() {\n return ENV;\n}\n\nexport let ENV: Environment = null;\nexport function setEnvironmentGlobal(environment: Environment) {\n ENV = environment;\n}\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {NamedGradientMap} from './tape';\nimport {Tensor} from './tensor';\nimport {DataType, RecursiveArray} from './types';\n\nconst kernelRegistry: Map = new Map();\nconst gradRegistry: Map = new Map();\n\nexport type DataId = object;\n\ntype AttributeValue =\n number|number[]|boolean|boolean[]|string|string[]|NamedAttrMap;\n\n/** These are extra non-tensor/primitive params passed to kernel functions. */\nexport type Attribute = AttributeValue|RecursiveArray;\n\n/** Specifies the code to run when executing a kernel. */\nexport type KernelFunc = (params: {\n inputs: NamedTensorInfoMap,\n backend: {},\n attrs?: NamedAttrMap,\n}) => TensorInfo|TensorInfo[];\n\n/** The function to run when computing a gradient during backprop. */\nexport type GradFunc =\n (dy: Tensor|Tensor[], saved: Tensor[], attrs: NamedAttrMap) =>\n NamedGradientMap;\n\n/** Function that gets called after the backend initializes. */\nexport type KernelSetupFunc = (backend: {}) => void;\n/** Function that gets called right before the backend is disposed. */\nexport type KernelDisposeFunc = KernelSetupFunc;\n\n/** Config object for registering a kernel in the global registry. */\nexport interface KernelConfig {\n kernelName: string;\n backendName: string;\n kernelFunc: KernelFunc;\n setupFunc?: KernelSetupFunc;\n disposeFunc?: KernelDisposeFunc;\n}\n\n/** Config object for registering a gradient in the global registry. */\nexport interface GradConfig {\n kernelName: string;\n inputsToSave?: string[];\n // When saveAllInputs is true, all inputs will be saved. Only use this flag\n // if inputs is an array of Tensors.\n saveAllInputs?: boolean;\n outputsToSave?: boolean[];\n gradFunc: GradFunc;\n}\n\n/** Holds metadata for a given tensor. */\nexport interface TensorInfo {\n dataId: DataId;\n shape: number[];\n dtype: DataType;\n}\n\nexport interface NamedTensorInfoMap {\n [name: string]: TensorInfo;\n}\n\nexport interface NamedAttrMap {\n [name: string]: Attribute;\n}\n\n/**\n * Returns the kernel function (code) associated with the provided names.\n *\n * @param kernelName The official name of the kernel.\n * @param backendName The official name of the backend.\n */\nexport function getKernel(\n kernelName: string, backendName: string): KernelConfig {\n const key = makeKey(kernelName, backendName);\n return kernelRegistry.get(key);\n}\n\n/**\n * Returns the registered gradient info associated with the provided kernel.\n * @param kernelName The official TF kernel name.\n */\nexport function getGradient(kernelName: string): GradConfig {\n return gradRegistry.get(kernelName);\n}\n\nexport function getKernelsForBackend(backendName: string): KernelConfig[] {\n const it = kernelRegistry.entries();\n const result: KernelConfig[] = [];\n\n while (true) {\n const {done, value} = it.next();\n if (done) {\n break;\n }\n const [key, config] = value;\n const [backend, ] = key.split('_');\n if (backend === backendName) {\n result.push(config);\n }\n }\n return result;\n}\n\n/**\n * Registers the function (forward pass) for the kernel in a global registry.\n *\n * @param config A config object with the following properties:\n * - `kernelName` The official name of the kernel.\n * - `backendName` The official name of the backend.\n * - `kernelFunc` The function to run during the forward pass of the kernel.\n * - `setupFunc` Optional. Gets called once, after the backend initializes.\n * - `disposeFunc` Optional. Gets called once, right before the backend is\n * disposed.\n */\nexport function registerKernel(config: KernelConfig) {\n const {kernelName, backendName} = config;\n const key = makeKey(kernelName, backendName);\n if (kernelRegistry.has(key)) {\n throw new Error(\n `The kernel '${kernelName}' for backend ` +\n `'${backendName}' is already registered`);\n }\n kernelRegistry.set(key, config);\n}\n\n/**\n * Registers a gradient function for a given kernel in the global registry,\n * to be used during the back-propagation of that kernel.\n *\n * @param config An object with the following properties:\n * - `kernelName` The name of the kernel that the gradient function is for.\n * - `gradFunc` The function to run during back-propagation.\n */\nexport function registerGradient(config: GradConfig) {\n const {kernelName} = config;\n if (gradRegistry.has(kernelName)) {\n console.warn(`Overriding the gradient for '${kernelName}'`);\n }\n gradRegistry.set(kernelName, config);\n}\n\n/**\n * Removes the kernel function from the registry.\n *\n * @param kernelName The official name of the kernel.\n * @param backendName The official name of the backend.\n *\n */\nexport function unregisterKernel(\n kernelName: string, backendName: string): void {\n const key = makeKey(kernelName, backendName);\n if (!kernelRegistry.has(key)) {\n throw new Error(\n `The kernel '${kernelName}' for backend ` +\n `'${backendName}' is not registered`);\n }\n kernelRegistry.delete(key);\n}\n\n/** Removes the registered gradient from the global registry. */\nexport function unregisterGradient(kernelName: string): void {\n if (!gradRegistry.has(kernelName)) {\n throw new Error(\n `The gradient '${kernelName}' for backend is not registered`);\n }\n gradRegistry.delete(kernelName);\n}\n\nfunction makeKey(kernelName: string, backendName: string) {\n return `${backendName}_${kernelName}`;\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from './environment';\n\nimport {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray} from './types';\n\n/**\n * Shuffles the array in-place using Fisher-Yates algorithm.\n *\n * ```js\n * const a = [1, 2, 3, 4, 5];\n * tf.util.shuffle(a);\n * console.log(a);\n * ```\n *\n * @param array The array to shuffle in-place.\n */\n/** @doc {heading: 'Util', namespace: 'util'} */\n// tslint:disable-next-line:no-any\nexport function shuffle(array: any[]|Uint32Array|Int32Array|\n Float32Array): void {\n let counter = array.length;\n let temp = 0;\n let index = 0;\n // While there are elements in the array\n while (counter > 0) {\n // Pick a random index\n index = (Math.random() * counter) | 0;\n // Decrease counter by 1\n counter--;\n // And swap the last element with it\n temp = array[counter];\n array[counter] = array[index];\n array[index] = temp;\n }\n}\n\n/** Clamps a value to a specified range. */\nexport function clamp(min: number, x: number, max: number): number {\n return Math.max(min, Math.min(x, max));\n}\n\nexport function nearestLargerEven(val: number): number {\n return val % 2 === 0 ? val : val + 1;\n}\n\nexport function sum(arr: number[]): number {\n let sum = 0;\n for (let i = 0; i < arr.length; i++) {\n sum += arr[i];\n }\n return sum;\n}\n\n/**\n * Returns a sample from a uniform [a, b) distribution.\n *\n * @param a The minimum support (inclusive).\n * @param b The maximum support (exclusive).\n * @return A pseudorandom number on the half-open interval [a,b).\n */\nexport function randUniform(a: number, b: number) {\n const r = Math.random();\n return (b * r) + (1 - r) * a;\n}\n\n/** Returns the squared Euclidean distance between two vectors. */\nexport function distSquared(a: FlatVector, b: FlatVector): number {\n let result = 0;\n for (let i = 0; i < a.length; i++) {\n const diff = Number(a[i]) - Number(b[i]);\n result += diff * diff;\n }\n return result;\n}\n\n/**\n * Asserts that the expression is true. Otherwise throws an error with the\n * provided message.\n *\n * ```js\n * const x = 2;\n * tf.util.assert(x === 2, 'x is not 2');\n * ```\n *\n * @param expr The expression to assert (as a boolean).\n * @param msg A function that returns the message to report when throwing an\n * error. We use a function for performance reasons.\n */\n/** @doc {heading: 'Util', namespace: 'util'} */\nexport function assert(expr: boolean, msg: () => string) {\n if (!expr) {\n throw new Error(typeof msg === 'string' ? msg : msg());\n }\n}\n\nexport function assertShapesMatch(\n shapeA: number[], shapeB: number[], errorMessagePrefix = ''): void {\n assert(\n arraysEqual(shapeA, shapeB),\n () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`);\n}\n\nexport function assertNonNull(a: TensorLike): void {\n assert(\n a != null,\n () => `The input to the tensor constructor must be a non-null value.`);\n}\n\n// NOTE: We explicitly type out what T extends instead of any so that\n// util.flatten on a nested array of number doesn't try to infer T as a\n// number[][], causing us to explicitly type util.flatten().\n/**\n * Flattens an arbitrarily nested array.\n *\n * ```js\n * const a = [[1, 2], [3, 4], [5, [6, [7]]]];\n * const flat = tf.util.flatten(a);\n * console.log(flat);\n * ```\n *\n * @param arr The nested array to flatten.\n * @param result The destination array which holds the elements.\n * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults\n * to false.\n */\n/** @doc {heading: 'Util', namespace: 'util'} */\nexport function\nflatten|TypedArray>(\n arr: T|RecursiveArray, result: T[] = [], skipTypedArray = false): T[] {\n if (result == null) {\n result = [];\n }\n if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) {\n for (let i = 0; i < arr.length; ++i) {\n flatten(arr[i], result, skipTypedArray);\n }\n } else {\n result.push(arr as T);\n }\n return result;\n}\n\n/**\n * Returns the size (number of elements) of the tensor given its shape.\n *\n * ```js\n * const shape = [3, 4, 2];\n * const size = tf.util.sizeFromShape(shape);\n * console.log(size);\n * ```\n */\n/** @doc {heading: 'Util', namespace: 'util'} */\nexport function sizeFromShape(shape: number[]): number {\n if (shape.length === 0) {\n // Scalar.\n return 1;\n }\n let size = shape[0];\n for (let i = 1; i < shape.length; i++) {\n size *= shape[i];\n }\n return size;\n}\n\nexport function isScalarShape(shape: number[]): boolean {\n return shape.length === 0;\n}\n\nexport function arraysEqual(n1: FlatVector, n2: FlatVector) {\n if (n1 === n2) {\n return true;\n }\n if (n1 == null || n2 == null) {\n return false;\n }\n\n if (n1.length !== n2.length) {\n return false;\n }\n for (let i = 0; i < n1.length; i++) {\n if (n1[i] !== n2[i]) {\n return false;\n }\n }\n return true;\n}\n\nexport function isInt(a: number): boolean {\n return a % 1 === 0;\n}\n\nexport function tanh(x: number): number {\n // tslint:disable-next-line:no-any\n if ((Math as any).tanh != null) {\n // tslint:disable-next-line:no-any\n return (Math as any).tanh(x);\n }\n if (x === Infinity) {\n return 1;\n } else if (x === -Infinity) {\n return -1;\n } else {\n const e2x = Math.exp(2 * x);\n return (e2x - 1) / (e2x + 1);\n }\n}\n\nexport function sizeToSquarishShape(size: number): [number, number] {\n const width = Math.ceil(Math.sqrt(size));\n return [width, Math.ceil(size / width)];\n}\n\n/**\n * Creates a new array with randomized indicies to a given quantity.\n *\n * ```js\n * const randomTen = tf.util.createShuffledIndices(10);\n * console.log(randomTen);\n * ```\n *\n * @param number Quantity of how many shuffled indicies to create.\n */\n/** @doc {heading: 'Util', namespace: 'util'} */\nexport function createShuffledIndices(n: number): Uint32Array {\n const shuffledIndices = new Uint32Array(n);\n for (let i = 0; i < n; ++i) {\n shuffledIndices[i] = i;\n }\n shuffle(shuffledIndices);\n return shuffledIndices;\n}\n\nexport function rightPad(a: string, size: number): string {\n if (size <= a.length) {\n return a;\n }\n return a + ' '.repeat(size - a.length);\n}\n\nexport function repeatedTry(\n checkFn: () => boolean, delayFn = (counter: number) => 0,\n maxCounter?: number): Promise {\n return new Promise((resolve, reject) => {\n let tryCount = 0;\n\n const tryFn = () => {\n if (checkFn()) {\n resolve();\n return;\n }\n\n tryCount++;\n\n const nextBackoff = delayFn(tryCount);\n\n if (maxCounter != null && tryCount >= maxCounter) {\n reject();\n return;\n }\n setTimeout(tryFn, nextBackoff);\n };\n\n tryFn();\n });\n}\n\n/**\n * Given the full size of the array and a shape that may contain -1 as the\n * implicit dimension, returns the inferred shape where -1 is replaced.\n * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].\n *\n * @param shape The shape, which may contain -1 in some dimension.\n * @param size The full size (number of elements) of the array.\n * @return The inferred shape where -1 is replaced with the inferred size.\n */\nexport function inferFromImplicitShape(\n shape: number[], size: number): number[] {\n let shapeProd = 1;\n let implicitIdx = -1;\n\n for (let i = 0; i < shape.length; ++i) {\n if (shape[i] >= 0) {\n shapeProd *= shape[i];\n } else if (shape[i] === -1) {\n if (implicitIdx !== -1) {\n throw Error(\n `Shapes can only have 1 implicit size. ` +\n `Found -1 at dim ${implicitIdx} and dim ${i}`);\n }\n implicitIdx = i;\n } else if (shape[i] < 0) {\n throw Error(`Shapes can not be < 0. Found ${shape[i]} at dim ${i}`);\n }\n }\n\n if (implicitIdx === -1) {\n if (size > 0 && size !== shapeProd) {\n throw Error(`Size(${size}) must match the product of shape ${shape}`);\n }\n return shape;\n }\n\n if (shapeProd === 0) {\n throw Error(\n `Cannot infer the missing size in [${shape}] when ` +\n `there are 0 elements`);\n }\n if (size % shapeProd !== 0) {\n throw Error(\n `The implicit shape can't be a fractional number. ` +\n `Got ${size} / ${shapeProd}`);\n }\n\n const newShape = shape.slice();\n newShape[implicitIdx] = size / shapeProd;\n return newShape;\n}\n\nexport function parseAxisParam(\n axis: number|number[], shape: number[]): number[] {\n const rank = shape.length;\n\n // Normalize input\n axis = axis == null ? shape.map((s, i) => i) : [].concat(axis);\n\n // Check for valid range\n assert(\n axis.every(ax => ax >= -rank && ax < rank),\n () =>\n `All values in axis param must be in range [-${rank}, ${rank}) but ` +\n `got axis ${axis}`);\n\n // Check for only integers\n assert(\n axis.every(ax => isInt(ax)),\n () => `All values in axis param must be integers but ` +\n `got axis ${axis}`);\n\n // Handle negative axis.\n return axis.map(a => a < 0 ? rank + a : a);\n}\n\n/** Reduces the shape by removing all dimensions of shape 1. */\nexport function squeezeShape(shape: number[], axis?: number[]):\n {newShape: number[], keptDims: number[]} {\n const newShape: number[] = [];\n const keptDims: number[] = [];\n const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;\n const axes = (axis == null || isEmptyArray) ?\n null :\n parseAxisParam(axis, shape).sort();\n let j = 0;\n for (let i = 0; i < shape.length; ++i) {\n if (axes != null) {\n if (axes[j] === i && shape[i] !== 1) {\n throw new Error(\n `Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);\n }\n if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {\n newShape.push(shape[i]);\n keptDims.push(i);\n }\n if (axes[j] <= i) {\n j++;\n }\n }\n if (shape[i] !== 1) {\n newShape.push(shape[i]);\n keptDims.push(i);\n }\n }\n return {newShape, keptDims};\n}\n\nexport function getTypedArrayFromDType(\n dtype: D, size: number): DataTypeMap[D] {\n let values = null;\n if (dtype == null || dtype === 'float32') {\n values = new Float32Array(size);\n } else if (dtype === 'int32') {\n values = new Int32Array(size);\n } else if (dtype === 'bool') {\n values = new Uint8Array(size);\n } else {\n throw new Error(`Unknown data type ${dtype}`);\n }\n return values as DataTypeMap[D];\n}\n\nexport function getArrayFromDType(\n dtype: D, size: number): DataTypeMap[D] {\n let values = null;\n if (dtype == null || dtype === 'float32') {\n values = new Float32Array(size);\n } else if (dtype === 'int32') {\n values = new Int32Array(size);\n } else if (dtype === 'bool') {\n values = new Uint8Array(size);\n } else if (dtype === 'string') {\n values = new Array<'string'>(size);\n } else {\n throw new Error(`Unknown data type ${dtype}`);\n }\n return values as DataTypeMap[D];\n}\n\nexport function checkConversionForErrors(\n vals: DataTypeMap[D]|number[], dtype: D): void {\n for (let i = 0; i < vals.length; i++) {\n const num = vals[i] as number;\n if (isNaN(num) || !isFinite(num)) {\n throw Error(`A tensor of type ${dtype} being uploaded contains ${num}.`);\n }\n }\n}\n\n/** Returns true if the dtype is valid. */\nexport function isValidDtype(dtype: DataType): boolean {\n return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' ||\n dtype === 'int32' || dtype === 'string';\n}\n\n/**\n * Returns true if the new type can't encode the old type without loss of\n * precision.\n */\nexport function hasEncodingLoss(oldType: DataType, newType: DataType): boolean {\n if (newType === 'complex64') {\n return false;\n }\n if (newType === 'float32' && oldType !== 'complex64') {\n return false;\n }\n if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') {\n return false;\n }\n if (newType === 'bool' && oldType === 'bool') {\n return false;\n }\n return true;\n}\n\nexport function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array {\n return a instanceof Float32Array || a instanceof Int32Array ||\n a instanceof Uint8Array;\n}\n\nexport function bytesPerElement(dtype: DataType): number {\n if (dtype === 'float32' || dtype === 'int32') {\n return 4;\n } else if (dtype === 'complex64') {\n return 8;\n } else if (dtype === 'bool') {\n return 1;\n } else {\n throw new Error(`Unknown dtype ${dtype}`);\n }\n}\n\n/**\n * Returns the approximate number of bytes allocated in the string array - 2\n * bytes per character. Computing the exact bytes for a native string in JS is\n * not possible since it depends on the encoding of the html page that serves\n * the website.\n */\nexport function bytesFromStringArray(arr: Uint8Array[]): number {\n if (arr == null) {\n return 0;\n }\n let bytes = 0;\n arr.forEach(x => bytes += x.length);\n return bytes;\n}\n\n/** Returns true if the value is a string. */\nexport function isString(value: {}): value is string {\n return typeof value === 'string' || value instanceof String;\n}\n\nexport function isBoolean(value: {}): boolean {\n return typeof value === 'boolean';\n}\n\nexport function isNumber(value: {}): boolean {\n return typeof value === 'number';\n}\n\nexport function inferDtype(values: TensorLike): DataType {\n if (Array.isArray(values)) {\n return inferDtype(values[0]);\n }\n if (values instanceof Float32Array) {\n return 'float32';\n } else if (values instanceof Int32Array || values instanceof Uint8Array) {\n return 'int32';\n } else if (isNumber(values)) {\n return 'float32';\n } else if (isString(values)) {\n return 'string';\n } else if (isBoolean(values)) {\n return 'bool';\n }\n return 'float32';\n}\n\nexport function isFunction(f: Function) {\n return !!(f && f.constructor && f.call && f.apply);\n}\n\nexport function nearestDivisor(size: number, start: number): number {\n for (let i = start; i < size; ++i) {\n if (size % i === 0) {\n return i;\n }\n }\n return size;\n}\n\nexport function computeStrides(shape: number[]): number[] {\n const rank = shape.length;\n if (rank < 2) {\n return [];\n }\n\n // Last dimension has implicit stride of 1, thus having D-1 (instead of D)\n // strides.\n const strides = new Array(rank - 1);\n strides[rank - 2] = shape[rank - 1];\n for (let i = rank - 3; i >= 0; --i) {\n strides[i] = strides[i + 1] * shape[i + 1];\n }\n return strides;\n}\n\nexport function toTypedArray(\n a: TensorLike, dtype: DataType, debugMode: boolean): TypedArray {\n if (dtype === 'string') {\n throw new Error('Cannot convert a string[] to a TypedArray');\n }\n if (Array.isArray(a)) {\n a = flatten(a);\n }\n if (debugMode) {\n checkConversionForErrors(a as number[], dtype);\n }\n if (noConversionNeeded(a, dtype)) {\n return a as TypedArray;\n }\n if (dtype == null || dtype === 'float32' || dtype === 'complex64') {\n return new Float32Array(a as number[]);\n } else if (dtype === 'int32') {\n return new Int32Array(a as number[]);\n } else if (dtype === 'bool') {\n const bool = new Uint8Array((a as number[]).length);\n for (let i = 0; i < bool.length; ++i) {\n if (Math.round((a as number[])[i]) !== 0) {\n bool[i] = 1;\n }\n }\n return bool;\n } else {\n throw new Error(`Unknown data type ${dtype}`);\n }\n}\n\nfunction createNestedArray(offset: number, shape: number[], a: TypedArray) {\n const ret = new Array();\n if (shape.length === 1) {\n const d = shape[0];\n for (let i = 0; i < d; i++) {\n ret[i] = a[offset + i];\n }\n } else {\n const d = shape[0];\n const rest = shape.slice(1);\n const len = rest.reduce((acc, c) => acc * c);\n for (let i = 0; i < d; i++) {\n ret[i] = createNestedArray(offset + i * len, rest, a);\n }\n }\n return ret;\n}\n\n// Provide a nested array of TypedArray in given shape.\nexport function toNestedArray(shape: number[], a: TypedArray) {\n if (shape.length === 0) {\n // Scalar type should return a single number.\n return a[0];\n }\n const size = shape.reduce((acc, c) => acc * c);\n if (size === 0) {\n // A tensor with shape zero should be turned into empty list.\n return [];\n }\n if (size !== a.length) {\n throw new Error(`[${shape}] does not match the input size.`);\n }\n\n return createNestedArray(0, shape, a);\n}\n\nfunction noConversionNeeded(a: TensorLike, dtype: DataType): boolean {\n return (a instanceof Float32Array && dtype === 'float32') ||\n (a instanceof Int32Array && dtype === 'int32') ||\n (a instanceof Uint8Array && dtype === 'bool');\n}\n\nexport function makeOnesTypedArray(\n size: number, dtype: D): DataTypeMap[D] {\n const array = makeZerosTypedArray(size, dtype);\n for (let i = 0; i < array.length; i++) {\n array[i] = 1;\n }\n return array;\n}\n\nexport function makeZerosTypedArray(\n size: number, dtype: D): DataTypeMap[D] {\n if (dtype == null || dtype === 'float32' || dtype === 'complex64') {\n return new Float32Array(size) as DataTypeMap[D];\n } else if (dtype === 'int32') {\n return new Int32Array(size) as DataTypeMap[D];\n } else if (dtype === 'bool') {\n return new Uint8Array(size) as DataTypeMap[D];\n } else {\n throw new Error(`Unknown data type ${dtype}`);\n }\n}\n\n/**\n * Returns the current high-resolution time in milliseconds relative to an\n * arbitrary time in the past. It works across different platforms (node.js,\n * browsers).\n *\n * ```js\n * console.log(tf.util.now());\n * ```\n */\n/** @doc {heading: 'Util', namespace: 'util'} */\nexport function now(): number {\n return env().platform.now();\n}\n\nexport function assertNonNegativeIntegerDimensions(shape: number[]) {\n shape.forEach(dimSize => {\n assert(\n Number.isInteger(dimSize) && dimSize >= 0,\n () =>\n `Tensor must have a shape comprised of positive integers but got ` +\n `shape [${shape}].`);\n });\n}\n\n/**\n * Returns a platform-specific implementation of\n * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).\n *\n * If `fetch` is defined on the global object (`window`, `process`, etc.),\n * `tf.util.fetch` returns that function.\n *\n * If not, `tf.util.fetch` returns a platform-specific solution.\n *\n * ```js\n * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs');\n * // handle response\n * ```\n */\n/** @doc {heading: 'Util'} */\nexport function fetch(\n path: string, requestInits?: RequestInit): Promise {\n return env().platform.fetch(path, requestInits);\n}\n\n/**\n * Encodes the provided string into bytes using the provided encoding scheme.\n *\n * @param s The string to encode.\n * @param encoding The encoding scheme. Defaults to utf-8.\n *\n */\n/** @doc {heading: 'Util'} */\nexport function encodeString(s: string, encoding = 'utf-8'): Uint8Array {\n encoding = encoding || 'utf-8';\n return env().platform.encode(s, encoding);\n}\n\n/**\n * Decodes the provided bytes into a string using the provided encoding scheme.\n * @param bytes The bytes to decode.\n *\n * @param encoding The encoding scheme. Defaults to utf-8.\n */\n/** @doc {heading: 'Util'} */\nexport function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string {\n encoding = encoding || 'utf-8';\n return env().platform.decode(bytes, encoding);\n}\n\n/**\n * Computes flat index for a given location (multidimentionsal index) in a\n * Tensor/multidimensional array.\n *\n * @param locs Location in the tensor.\n * @param rank Rank of the tensor.\n * @param strides Tensor strides.\n */\nexport function locToIndex(\n locs: number[], rank: number, strides: number[]): number {\n if (rank === 0) {\n return 0;\n } else if (rank === 1) {\n return locs[0];\n }\n let index = locs[locs.length - 1];\n for (let i = 0; i < locs.length - 1; ++i) {\n index += strides[i] * locs[i];\n }\n return index;\n}\n\n/**\n * Computes the location (multidimensional index) in a tensor/multidimentional\n * array for a given flat index.\n *\n * @param index Index in flat array.\n * @param rank Rank of tensor.\n * @param strides Strides of tensor.\n */\nexport function indexToLoc(\n index: number, rank: number, strides: number[]): number[] {\n if (rank === 0) {\n return [];\n } else if (rank === 1) {\n return [index];\n }\n const locs: number[] = new Array(rank);\n for (let i = 0; i < locs.length - 1; ++i) {\n locs[i] = Math.floor(index / strides[i]);\n index -= locs[i] * strides[i];\n }\n locs[locs.length - 1] = index;\n return locs;\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {BackendTimer} from './backends/backend';\nimport {Tensor} from './tensor';\nimport {NamedTensorMap} from './tensor_types';\nimport {DataType, DataTypeMap, TypedArray} from './types';\nimport * as util from './util';\n\nexport class Profiler {\n constructor(private backendTimer: BackendTimer, private logger?: Logger) {\n if (logger == null) {\n this.logger = new Logger();\n }\n }\n\n profileKernel(kernelName: string, inputs: NamedTensorMap, f: () => Tensor[]):\n Tensor[] {\n let outputs: Tensor[];\n const holdResultWrapperFn = () => {\n outputs = f();\n };\n const timer = this.backendTimer.time(holdResultWrapperFn);\n\n outputs.forEach(r => {\n // Dangling promise here because we don't want to propagate up\n // asynchronicity.\n r.data().then(vals => {\n checkComputationForErrors(vals, r.dtype, kernelName);\n\n timer.then(timing => {\n let extraInfo = '';\n if (timing.getExtraProfileInfo != null) {\n extraInfo = timing.getExtraProfileInfo();\n }\n\n this.logger.logKernelProfile(\n kernelName, r, vals, timing.kernelMs, inputs, extraInfo);\n });\n });\n });\n\n return outputs;\n }\n}\n\nexport function checkComputationForErrors(\n vals: DataTypeMap[D], dtype: D, kernelName: string): boolean {\n if (dtype !== 'float32') {\n // Only floating point computations will generate NaN values\n return false;\n }\n for (let i = 0; i < vals.length; i++) {\n const num = vals[i] as number;\n if (isNaN(num) || !isFinite(num)) {\n // Throwing custom exception so behavior is testable.\n console.warn(`Found ${num} in the result of '${kernelName}'`);\n return true;\n }\n }\n return false;\n}\n\nexport class Logger {\n logKernelProfile(\n name: string, result: Tensor, vals: TypedArray,\n timeMs: number|{error: string}, inputs: NamedTensorMap,\n extraInfo?: string) {\n const time = typeof timeMs === 'number' ? util.rightPad(`${timeMs}ms`, 9) :\n timeMs['error'];\n const paddedName = util.rightPad(name, 25);\n const rank = result.rank;\n const size = result.size;\n const shape = util.rightPad(result.shape.toString(), 14);\n let inputShapesDescription = '';\n\n for (const name in inputs) {\n const input = inputs[name];\n // The input might be a non-tensor (e.g HTMLImageElement), in which case\n // we claim the output shape as input shape.\n const inputShape = input.shape || result.shape;\n const inputRank = inputShape.length;\n inputShapesDescription +=\n `${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `;\n }\n\n console.log(\n `%c${paddedName}\\t%c${time}\\t%c${rank}D ${shape}\\t%c${size}\\t%c${\n inputShapesDescription}\\t%c${extraInfo}`,\n 'font-weight:bold', 'color:red', 'color:blue', 'color: orange',\n 'color: green', 'color: steelblue');\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {DataType, TypedArray} from './types';\nimport {computeStrides, isString, rightPad, sizeFromShape} from './util';\n\n// Maximum number of values before we decide to show ellipsis.\nconst FORMAT_LIMIT_NUM_VALS = 20;\n// Number of first and last values to show when displaying a, b,...,y, z.\nconst FORMAT_NUM_FIRST_LAST_VALS = 3;\n// Number of significant digits to show.\nconst FORMAT_NUM_SIG_DIGITS = 7;\n\nexport function tensorToString(\n vals: TypedArray|string[], shape: number[], dtype: DataType,\n verbose: boolean) {\n const strides = computeStrides(shape);\n const padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);\n const rank = shape.length;\n const valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);\n const lines = ['Tensor'];\n if (verbose) {\n lines.push(` dtype: ${dtype}`);\n lines.push(` rank: ${rank}`);\n lines.push(` shape: [${shape}]`);\n lines.push(` values:`);\n }\n lines.push(valsLines.map(l => ' ' + l).join('\\n'));\n return lines.join('\\n');\n}\n\nfunction computeMaxSizePerColumn(\n vals: TypedArray|string[], shape: number[], dtype: DataType,\n strides: number[]): number[] {\n const n = sizeFromShape(shape);\n const numCols = strides[strides.length - 1];\n const padPerCol = new Array(numCols).fill(0);\n const rank = shape.length;\n const valuesOrTuples =\n dtype === 'complex64' ? createComplexTuples(vals) : vals;\n\n if (rank > 1) {\n for (let row = 0; row < n / numCols; row++) {\n const offset = row * numCols;\n for (let j = 0; j < numCols; j++) {\n padPerCol[j] = Math.max(\n padPerCol[j],\n valToString(valuesOrTuples[offset + j], 0, dtype).length);\n }\n }\n }\n return padPerCol;\n}\n\nfunction valToString(\n val: number|string|[number, number], pad: number, dtype: DataType) {\n let valStr: string;\n if (Array.isArray(val)) {\n valStr = `${parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS))} + ` +\n `${parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS))}j`;\n } else if (isString(val)) {\n valStr = `'${val}'`;\n } else if (dtype === 'bool') {\n valStr = boolNumToString(val);\n } else {\n valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString();\n }\n\n return rightPad(valStr, pad);\n}\n\nfunction boolNumToString(v: number): string {\n return v === 0 ? 'false' : 'true';\n}\n\nfunction subTensorToString(\n vals: TypedArray|string[], shape: number[], dtype: DataType,\n strides: number[], padPerCol: number[], isLast = true): string[] {\n const storagePerElement = dtype === 'complex64' ? 2 : 1;\n\n const size = shape[0];\n const rank = shape.length;\n if (rank === 0) {\n if (dtype === 'complex64') {\n const complexTuple = createComplexTuples(vals);\n return [valToString(complexTuple[0], 0, dtype)];\n }\n if (dtype === 'bool') {\n return [boolNumToString(vals[0] as number)];\n }\n return [vals[0].toString()];\n }\n\n if (rank === 1) {\n if (size > FORMAT_LIMIT_NUM_VALS) {\n const firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;\n\n let firstVals = Array.from(\n vals.slice(0, firstValsSize));\n let lastVals = Array.from(vals.slice(\n (size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement,\n size * storagePerElement));\n if (dtype === 'complex64') {\n firstVals = createComplexTuples(firstVals);\n lastVals = createComplexTuples(lastVals);\n }\n return [\n '[' +\n firstVals.map((x, i) => valToString(x, padPerCol[i], dtype))\n .join(', ') +\n ', ..., ' +\n lastVals\n .map(\n (x, i) => valToString(\n x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype))\n .join(', ') +\n ']'\n ];\n }\n const displayVals: Array =\n dtype === 'complex64' ? createComplexTuples(vals) :\n Array.from(vals);\n\n return [\n '[' +\n displayVals.map((x, i) => valToString(x, padPerCol[i], dtype))\n .join(', ') +\n ']'\n ];\n }\n\n // The array is rank 2 or more.\n const subshape = shape.slice(1);\n const substrides = strides.slice(1);\n const stride = strides[0] * storagePerElement;\n const lines: string[] = [];\n if (size > FORMAT_LIMIT_NUM_VALS) {\n for (let i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {\n const start = i * stride;\n const end = start + stride;\n lines.push(...subTensorToString(\n vals.slice(start, end), subshape, dtype, substrides, padPerCol,\n false /* isLast */));\n }\n lines.push('...');\n for (let i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) {\n const start = i * stride;\n const end = start + stride;\n lines.push(...subTensorToString(\n vals.slice(start, end), subshape, dtype, substrides, padPerCol,\n i === size - 1 /* isLast */));\n }\n } else {\n for (let i = 0; i < size; i++) {\n const start = i * stride;\n const end = start + stride;\n lines.push(...subTensorToString(\n vals.slice(start, end), subshape, dtype, substrides, padPerCol,\n i === size - 1 /* isLast */));\n }\n }\n const sep = rank === 2 ? ',' : '';\n lines[0] = '[' + lines[0] + sep;\n for (let i = 1; i < lines.length - 1; i++) {\n lines[i] = ' ' + lines[i] + sep;\n }\n let newLineSep = ',\\n';\n for (let i = 2; i < rank; i++) {\n newLineSep += '\\n';\n }\n lines[lines.length - 1] =\n ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep);\n return lines;\n}\n\nfunction createComplexTuples(vals: Array<{}>|\n TypedArray): Array<[number, number]> {\n const complexTuples: Array<[number, number]> = [];\n for (let i = 0; i < vals.length; i += 2) {\n complexTuples.push([vals[i], vals[i + 1]] as [number, number]);\n }\n return complexTuples;\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {tensorToString} from './tensor_format';\nimport {ArrayMap, BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, Rank, ShapeMap, SingleValueMap, TensorLike, TensorLike1D, TensorLike3D, TensorLike4D, TypedArray} from './types';\nimport * as util from './util';\nimport {computeStrides, toNestedArray} from './util';\n\nexport interface TensorData {\n dataId?: DataId;\n values?: DataTypeMap[D];\n}\n\n// This interface mimics KernelBackend (in backend.ts), which would create a\n// circular dependency if imported.\nexport interface Backend {}\n\n/**\n * A mutable object, similar to `tf.Tensor`, that allows users to set values\n * at locations before converting to an immutable `tf.Tensor`.\n *\n * See `tf.buffer` for creating a tensor buffer.\n */\n/** @doc {heading: 'Tensors', subheading: 'Classes'} */\nexport class TensorBuffer {\n size: number;\n shape: ShapeMap[R];\n strides: number[];\n values: DataTypeMap[D];\n\n constructor(shape: ShapeMap[R], public dtype: D, values?: DataTypeMap[D]) {\n this.shape = shape.slice() as ShapeMap[R];\n this.size = util.sizeFromShape(shape);\n\n if (values != null) {\n const n = values.length;\n util.assert(\n n === this.size,\n () => `Length of values '${n}' does not match the size ` +\n `inferred by the shape '${this.size}'.`);\n }\n if (dtype === 'complex64') {\n throw new Error(\n `complex64 dtype TensorBuffers are not supported. Please create ` +\n `a TensorBuffer for the real and imaginary parts separately and ` +\n `call tf.complex(real, imag).`);\n }\n this.values = values || util.getArrayFromDType(dtype, this.size);\n this.strides = computeStrides(shape);\n }\n\n /**\n * Sets a value in the buffer at a given location.\n *\n * @param value The value to set.\n * @param locs The location indices.\n */\n /** @doc {heading: 'Tensors', subheading: 'Creation'} */\n set(value: SingleValueMap[D], ...locs: number[]): void {\n if (locs.length === 0) {\n locs = [0];\n }\n util.assert(\n locs.length === this.rank,\n () => `The number of provided coordinates (${locs.length}) must ` +\n `match the rank (${this.rank})`);\n\n const index = this.locToIndex(locs);\n this.values[index] = value as number;\n }\n\n /**\n * Returns the value in the buffer at the provided location.\n *\n * @param locs The location indices.\n */\n /** @doc {heading: 'Tensors', subheading: 'Creation'} */\n get(...locs: number[]): SingleValueMap[D] {\n if (locs.length === 0) {\n locs = [0];\n }\n let i = 0;\n for (const loc of locs) {\n if (loc < 0 || loc >= this.shape[i]) {\n const msg = `Requested out of range element at ${locs}. ` +\n ` Buffer shape=${this.shape}`;\n throw new Error(msg);\n }\n i++;\n }\n let index = locs[locs.length - 1];\n for (let i = 0; i < locs.length - 1; ++i) {\n index += this.strides[i] * locs[i];\n }\n return this.values[index] as SingleValueMap[D];\n }\n\n locToIndex(locs: number[]): number {\n if (this.rank === 0) {\n return 0;\n } else if (this.rank === 1) {\n return locs[0];\n }\n let index = locs[locs.length - 1];\n for (let i = 0; i < locs.length - 1; ++i) {\n index += this.strides[i] * locs[i];\n }\n return index;\n }\n\n indexToLoc(index: number): number[] {\n if (this.rank === 0) {\n return [];\n } else if (this.rank === 1) {\n return [index];\n }\n const locs: number[] = new Array(this.shape.length);\n for (let i = 0; i < locs.length - 1; ++i) {\n locs[i] = Math.floor(index / this.strides[i]);\n index -= locs[i] * this.strides[i];\n }\n locs[locs.length - 1] = index;\n return locs;\n }\n\n get rank() {\n return this.shape.length;\n }\n\n /**\n * Creates an immutable `tf.Tensor` object from the buffer.\n */\n /** @doc {heading: 'Tensors', subheading: 'Creation'} */\n toTensor(): Tensor {\n return trackerFn().makeTensor(this.values, this.shape, this.dtype) as\n Tensor;\n }\n}\n\nexport interface TensorTracker {\n makeTensor(\n values: DataValues, shape: number[], dtype: DataType,\n backend?: Backend): Tensor;\n makeVariable(\n initialValue: Tensor, trainable?: boolean, name?: string,\n dtype?: DataType): Variable;\n incRef(a: Tensor, backend: Backend): void;\n disposeTensor(t: Tensor): void;\n disposeVariable(v: Variable): void;\n read(dataId: DataId): Promise;\n readSync(dataId: DataId): BackendValues;\n}\n\n/**\n * The Tensor class calls into this handler to delegate chaining operations.\n */\nexport interface OpHandler {\n cast(x: T, dtype: DataType): T;\n buffer(\n shape: ShapeMap[R], dtype: D,\n values?: DataTypeMap[D]): TensorBuffer;\n print(x: T, verbose: boolean): void;\n reshape(x: Tensor, shape: ShapeMap[R2]): Tensor;\n expandDims(x: Tensor, axis: number): Tensor;\n cumsum(\n x: Tensor, axis: number, exclusive: boolean, reverse: boolean): T;\n squeeze(x: Tensor, axis?: number[]): T;\n clone(x: T): T;\n gather(x: T, indices: Tensor|TensorLike, axis: number): T;\n matMul(\n a: T, b: T|TensorLike, transposeA: boolean, transposeB: boolean): T;\n dot(t1: Tensor, t2: Tensor|TensorLike): Tensor;\n norm(\n x: Tensor, ord: number|'euclidean'|'fro', axis: number|number[],\n keepDims: boolean): Tensor;\n slice>(\n x: T, begin: number|number[], size?: number|number[]): T;\n split(\n x: T, numOrSizeSplits: number[]|number, axis?: number): T[];\n reverse(x: T, axis?: number|number[]): T;\n concat(tensors: Array, axis: number): T;\n stack(tensors: Array, axis: number): Tensor;\n unstack(value: T, axis: number): Tensor[];\n all(x: Tensor, axis: number|number[], keepDims: boolean): T;\n any(x: Tensor, axis: number|number[], keepDims: boolean): T;\n logSumExp(\n x: Tensor, axis: number|number[], keepDims: boolean): T;\n sum(x: Tensor, axis: number|number[], keepDims: boolean): T;\n prod(x: Tensor, axis: number|number[], keepDims: boolean):\n T;\n mean(x: Tensor, axis: number|number[], keepDims: boolean):\n T;\n min(x: Tensor, axis: number|number[], keepDims: boolean): T;\n max(x: Tensor, axis: number|number[], keepDims: boolean): T;\n argMin(x: Tensor, axis: number): T;\n argMax(x: Tensor, axis: number): T;\n addStrict(a: T, b: T|TensorLike): T;\n atan2(a: Tensor, b: Tensor|TensorLike): T;\n sub(a: Tensor, b: Tensor|TensorLike): T;\n subStrict(a: T, b: T|TensorLike): T;\n pow(base: T, exp: Tensor|TensorLike): T;\n powStrict(base: T, exp: Tensor|TensorLike): T;\n mul(a: Tensor, b: Tensor|TensorLike): T;\n mulStrict(a: T, b: T|TensorLike): T;\n floorDiv(a: Tensor, b: Tensor|TensorLike): T;\n divStrict(a: T, b: T|TensorLike): T;\n mod(a: Tensor, b: Tensor|TensorLike): T;\n modStrict(a: T, b: T|TensorLike): T;\n minimum(a: Tensor, b: Tensor|TensorLike): T;\n minimumStrict(a: T, b: T|TensorLike): T;\n maximum(a: Tensor, b: Tensor|TensorLike): T;\n maximumStrict(a: T, b: T|TensorLike): T;\n squaredDifferenceStrict(a: T, b: T|TensorLike): T;\n logicalNot(x: T): T;\n logicalAnd(a: Tensor, b: Tensor|TensorLike): T;\n logicalOr(a: Tensor, b: Tensor|TensorLike): T;\n logicalXor(a: Tensor, b: Tensor|TensorLike): T;\n where(condition: Tensor|TensorLike, a: T, b: T|TensorLike):\n T;\n notEqual(a: Tensor, b: Tensor|TensorLike): T;\n notEqualStrict(a: T, b: T|TensorLike): T;\n less(a: Tensor, b: Tensor|TensorLike): T;\n lessStrict(a: T, b: T|TensorLike): T;\n equal(a: Tensor, b: Tensor|TensorLike): T;\n equalStrict(a: T, b: T|TensorLike): T;\n lessEqual(a: Tensor, b: Tensor|TensorLike): T;\n lessEqualStrict(a: T, b: T|TensorLike): T;\n greater(a: Tensor, b: Tensor|TensorLike): T;\n greaterStrict(a: T, b: T|TensorLike): T;\n greaterEqual(a: Tensor, b: Tensor|TensorLike): T;\n greaterEqualStrict(a: T, b: T|TensorLike): T;\n neg(x: T): T;\n ceil(x: T): T;\n floor(x: T): T;\n sign(x: T): T;\n isNaN(x: T): T;\n isInf(x: T): T;\n isFinite(x: T): T;\n round(x: T): T;\n exp(x: T): T;\n expm1(x: T): T;\n log(x: T): T;\n log1p(x: T): T;\n sqrt(x: T): T;\n rsqrt(x: T): T;\n square(x: T): T;\n reciprocal(x: T): T;\n abs(x: T): T;\n clipByValue(\n x: T, clipValueMin: number, clipValueMax: number): T;\n sigmoid(x: T): T;\n logSigmoid(x: T): T;\n softplus(x: T): T;\n zerosLike(x: T): T;\n onesLike(x: T): T;\n sin(x: T): T;\n cos(x: T): T;\n tan(x: T): T;\n asin(x: T): T;\n acos(x: T): T;\n atan(x: T): T;\n sinh(x: T): T;\n cosh(x: T): T;\n tanh(x: T): T;\n asinh(x: T): T;\n acosh(x: T): T;\n atanh(x: T): T;\n erf(x: T): T;\n step(x: T, alpha: number): T;\n relu(x: T): T;\n relu6(x: T): T;\n elu(x: T): T;\n selu(x: T): T;\n leakyRelu(x: T, alpha: number): T;\n prelu(x: T, alpha: T|TensorLike): T;\n softmax(logits: T, dim: number): T;\n logSoftmax(logits: T, axis: number): T;\n image: {\n resizeBilinear(\n images: T, size: [number, number], alignCorners: boolean): T;\n resizeNearestNeighbor(\n images: T, size: [number, number], alignCorners: boolean): T;\n };\n conv1d(\n x: T, filter: Tensor3D|TensorLike3D, stride: number,\n pad: 'valid'|'same'|number, dataFormat: 'NWC'|'NCW', dilation: number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T;\n conv2d(\n x: T, filter: Tensor4D|TensorLike4D, strides: [number, number]|number,\n pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW',\n dilations: [number, number]|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T;\n conv2dTranspose(\n x: T, filter: Tensor4D|TensorLike4D,\n outputShape: [number, number, number, number]|[number, number, number],\n strides: [number, number]|number, pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T;\n depthwiseConv2d(\n x: T, filter: Tensor4D|TensorLike4D, strides: [number, number]|number,\n pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW',\n dilations: [number, number]|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T;\n separableConv2d(\n x: T|TensorLike, depthwiseFilter: Tensor4D|TensorLike4D,\n pointwiseFilter: Tensor4D|TensorLike, strides: [number, number]|number,\n pad: 'valid'|'same', dilation: [number, number]|number,\n dataFormat: 'NHWC'|'NCHW'): T;\n maxPool(\n x: T, filterSize: [number, number]|number,\n strides: [number, number]|number, pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T;\n avgPool(\n x: T, filterSize: [number, number]|number,\n strides: [number, number]|number, pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T;\n pool(\n input: T, windowShape: [number, number]|number, poolingType: 'avg'|'max',\n padding: 'valid'|'same'|number, diationRate?: [number, number]|number,\n strides?: [number, number]|number): T;\n localResponseNormalization(\n x: T, depthRadius: number, bias: number, alpha: number, beta: number): T;\n unsortedSegmentSum(\n x: T, segmentIds: Tensor1D|TensorLike1D, numSegments: number): T;\n batchToSpaceND(\n x: T, blockShape: number[], crops: number[][]): T;\n spaceToBatchND(\n x: T, blockShape: number[], paddings: number[][]): T;\n topk(x: T, k: number, sorted: boolean):\n {values: T, indices: T};\n stridedSlice(\n x: Tensor, begin: number[], end: number[], strides: number[],\n beginMask: number, endMask: number, ellipsisMask: number,\n newAxisMask: number, shrinkAxisMask: number): Tensor;\n depthToSpace(x: Tensor4D, blockSize: number, dataFormat: string): Tensor4D;\n spectral: {\n fft(x: Tensor): Tensor; ifft(x: Tensor): Tensor; rfft(x: Tensor): Tensor;\n irfft(x: Tensor): Tensor\n };\n}\n\n// For tracking tensor creation and disposal.\nlet trackerFn: () => TensorTracker = null;\n// Used by chaining methods to call into ops.\nlet opHandler: OpHandler = null;\n// Used to warn about deprecated methods.\nlet deprecationWarningFn: (msg: string) => void = null;\n// This here so that we can use this method on dev branches and keep the\n// functionality at master.\n// tslint:disable-next-line:no-unused-expression\n[deprecationWarningFn];\n\n/**\n * An external consumer can register itself as the tensor tracker. This way\n * the Tensor class can notify the tracker for every tensor created and\n * disposed.\n */\nexport function setTensorTracker(fn: () => TensorTracker) {\n trackerFn = fn;\n}\n\n/**\n * An external consumer can register itself as the op handler. This way the\n * Tensor class can have chaining methods that call into ops via the op\n * handler.\n */\nexport function setOpHandler(handler: OpHandler) {\n opHandler = handler;\n}\n\n/**\n * Sets the deprecation warning function to be used by this file. This way the\n * Tensor class can be a leaf but still use the environment.\n */\nexport function setDeprecationWarningFn(fn: (msg: string) => void) {\n deprecationWarningFn = fn;\n}\n\n/**\n * We wrap data id since we use weak map to avoid memory leaks.\n * Since we have our own memory management, we have a reference counter\n * mapping a tensor to its data, so there is always a pointer (even if that\n * data is otherwise garbage collectable).\n * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/\n * Global_Objects/WeakMap\n */\nexport type DataId = object; // object instead of {} to force non-primitive.\n\n// Declare this namespace to make Tensor class augmentation work in google3.\nexport declare namespace Tensor {}\n/**\n * A `tf.Tensor` object represents an immutable, multidimensional array of\n * numbers that has a shape and a data type.\n *\n * See `tf.tensor` for details on how to create a `tf.Tensor`.\n */\n/** @doc {heading: 'Tensors', subheading: 'Classes'} */\nexport class Tensor {\n /** Unique id of this tensor. */\n readonly id: number;\n /**\n * Id of the bucket holding the data for this tensor. Multiple arrays can\n * point to the same bucket (e.g. when calling array.reshape()).\n */\n dataId: DataId;\n /** The shape of the tensor. */\n readonly shape: ShapeMap[R];\n /** Number of elements in the tensor. */\n readonly size: number;\n /** The data type for the array. */\n readonly dtype: DataType;\n /** The rank type for the array (see `Rank` enum). */\n readonly rankType: R;\n\n /** Whether this tensor has been globally kept. */\n kept = false;\n /** The id of the scope this tensor is being tracked in. */\n scopeId: number;\n\n /**\n * Number of elements to skip in each dimension when indexing. See\n * https://docs.scipy.org/doc/numpy/reference/generated/\\\n * numpy.ndarray.strides.html\n */\n readonly strides: number[];\n\n constructor(shape: ShapeMap[R], dtype: DataType, dataId: DataId, id: number) {\n this.shape = shape.slice() as ShapeMap[R];\n this.dtype = dtype || 'float32';\n this.size = util.sizeFromShape(shape);\n this.strides = computeStrides(shape);\n this.dataId = dataId;\n this.id = id;\n this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher') as R;\n }\n\n /** Flatten a Tensor to a 1D array. */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n flatten(): Tensor1D {\n this.throwIfDisposed();\n return this.as1D();\n }\n\n /** Converts a size-1 `tf.Tensor` to a `tf.Scalar`. */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n asScalar(): Scalar {\n this.throwIfDisposed();\n util.assert(this.size === 1, () => 'The array must have only 1 element.');\n return this.reshape([]);\n }\n\n /** Converts a `tf.Tensor` to a `tf.Tensor1D`. */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n as1D(): Tensor1D {\n this.throwIfDisposed();\n return this.reshape([this.size]);\n }\n\n /**\n * Converts a `tf.Tensor` to a `tf.Tensor2D`.\n *\n * @param rows Number of rows in `tf.Tensor2D`.\n * @param columns Number of columns in `tf.Tensor2D`.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n as2D(rows: number, columns: number): Tensor2D {\n this.throwIfDisposed();\n return this.reshape([rows, columns]);\n }\n\n /**\n * Converts a `tf.Tensor` to a `tf.Tensor3D`.\n *\n * @param rows Number of rows in `tf.Tensor3D`.\n * @param columns Number of columns in `tf.Tensor3D`.\n * @param depth Depth of `tf.Tensor3D`.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n as3D(rows: number, columns: number, depth: number): Tensor3D {\n this.throwIfDisposed();\n return this.reshape([rows, columns, depth]);\n }\n\n /**\n * Converts a `tf.Tensor` to a `tf.Tensor4D`.\n *\n * @param rows Number of rows in `tf.Tensor4D`.\n * @param columns Number of columns in `tf.Tensor4D`.\n * @param depth Depth of `tf.Tensor4D`.\n * @param depth2 4th dimension of `tf.Tensor4D`.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n as4D(rows: number, columns: number, depth: number, depth2: number): Tensor4D {\n this.throwIfDisposed();\n return this.reshape([rows, columns, depth, depth2]);\n }\n\n /**\n * Converts a `tf.Tensor` to a `tf.Tensor5D`.\n *\n * @param rows Number of rows in `tf.Tensor5D`.\n * @param columns Number of columns in `tf.Tensor5D`.\n * @param depth Depth of `tf.Tensor5D`.\n * @param depth2 4th dimension of `tf.Tensor5D`.\n * @param depth3 5th dimension of 'tf.Tensor5D'\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n as5D(\n rows: number, columns: number, depth: number, depth2: number,\n depth3: number): Tensor5D {\n this.throwIfDisposed();\n return this.reshape([rows, columns, depth, depth2, depth3]);\n }\n\n /**\n * Casts a `tf.Tensor` to a specified dtype.\n *\n * @param dtype Data-type to cast the tensor to.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n asType(this: T, dtype: DataType): T {\n this.throwIfDisposed();\n return opHandler.cast(this, dtype);\n }\n\n get rank(): number {\n return this.shape.length;\n }\n\n /**\n * Returns a promise of `tf.TensorBuffer` that holds the underlying data.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n async buffer(): Promise> {\n const vals = await this.data();\n return opHandler.buffer(this.shape, this.dtype as D, vals);\n }\n\n /** Returns a `tf.TensorBuffer` that holds the underlying data. */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n bufferSync(): TensorBuffer {\n return opHandler.buffer(this.shape, this.dtype as D, this.dataSync());\n }\n\n /**\n * Returns the tensor data as a nested array. The transfer of data is done\n * asynchronously.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n async array(): Promise {\n const vals = await this.data();\n return toNestedArray(this.shape, vals) as ArrayMap[R];\n }\n\n /**\n * Returns the tensor data as a nested array. The transfer of data is done\n * synchronously.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n arraySync(): ArrayMap[R] {\n return toNestedArray(this.shape, this.dataSync()) as ArrayMap[R];\n }\n\n /**\n * Asynchronously downloads the values from the `tf.Tensor`. Returns a\n * promise of `TypedArray` that resolves when the computation has finished.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n async data(): Promise {\n this.throwIfDisposed();\n const data = trackerFn().read(this.dataId);\n if (this.dtype === 'string') {\n const bytes = await data as Uint8Array[];\n try {\n return bytes.map(b => util.decodeString(b)) as DataTypeMap[D];\n } catch {\n throw new Error(\n 'Failed to decode the string bytes into utf-8. ' +\n 'To get the original bytes, call tensor.bytes().');\n }\n }\n return data as Promise;\n }\n\n /**\n * Synchronously downloads the values from the `tf.Tensor`. This blocks the\n * UI thread until the values are ready, which can cause performance issues.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n dataSync(): DataTypeMap[D] {\n this.throwIfDisposed();\n const data = trackerFn().readSync(this.dataId);\n if (this.dtype === 'string') {\n try {\n return (data as Uint8Array[]).map(b => util.decodeString(b)) as\n DataTypeMap[D];\n } catch {\n throw new Error(\n 'Failed to decode the string bytes into utf-8. ' +\n 'To get the original bytes, call tensor.bytes().');\n }\n }\n return data as DataTypeMap[D];\n }\n\n /** Returns the underlying bytes of the tensor's data. */\n async bytes(): Promise {\n this.throwIfDisposed();\n const data = await trackerFn().read(this.dataId);\n if (this.dtype === 'string') {\n return data as Uint8Array[];\n } else {\n return new Uint8Array((data as TypedArray).buffer);\n }\n }\n\n /**\n * Disposes `tf.Tensor` from memory.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n dispose(): void {\n if (this.isDisposed) {\n return;\n }\n trackerFn().disposeTensor(this);\n this.isDisposedInternal = true;\n }\n\n protected isDisposedInternal = false;\n get isDisposed(): boolean {\n return this.isDisposedInternal;\n }\n\n private throwIfDisposed() {\n if (this.isDisposed) {\n throw new Error(`Tensor is disposed.`);\n }\n }\n\n /** Casts the array to type `float32` */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n toFloat(this: T): T {\n return this.asType('float32');\n }\n\n /** Casts the array to type `int32` */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n toInt() {\n return this.asType('int32');\n }\n\n /** Casts the array to type `bool` */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n toBool() {\n return this.asType('bool');\n }\n\n /**\n * Prints the `tf.Tensor`. See `tf.print` for details.\n *\n * @param verbose Whether to print verbose information about the tensor,\n * including dtype and size.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n print(verbose = false): void {\n return opHandler.print(this, verbose);\n }\n\n /**\n * Reshapes the tensor into the provided shape.\n * See `tf.reshape` for more details.\n *\n * @param newShape An array of integers defining the output tensor shape.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n reshape(newShape: ShapeMap[R2]): Tensor {\n this.throwIfDisposed();\n return opHandler.reshape(this, newShape);\n }\n\n /**\n * Reshapes the tensor into the shape of the provided tensor.\n *\n * @param x The tensor of required shape.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n reshapeAs(x: T): T {\n this.throwIfDisposed();\n return this.reshape(x.shape) as T;\n }\n\n /**\n * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension\n * into the tensor's shape. See `tf.expandDims` for details.\n *\n * @param axis The dimension index at which to insert shape of 1. Defaults to\n * 0 (the first dimension).\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n expandDims(axis = 0): Tensor {\n return opHandler.expandDims(this, axis);\n }\n\n /**\n * Returns the cumulative sum of the `tf.Tensor` along `axis`.\n *\n * @param axis The axis along which to sum. Optional. Defaults to 0.\n * @param exclusive Whether to perform exclusive cumulative sum. Defaults to\n * false. If set to true then the sum of each tensor entry does not\n * include its own value, but only the values previous to it along the\n * specified axis.\n * @param reverse Whether to sum in the opposite direction. Defaults to\n * false.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n cumsum(axis = 0, exclusive = false, reverse = false): T {\n return opHandler.cumsum(this, axis, exclusive, reverse);\n }\n\n /**\n * Returns a `tf.Tensor` with dimensions of size 1 removed from the shape.\n * See `tf.squeeze` for more details.\n *\n * @param axis A list of numbers. If specified, only squeezes the\n * dimensions listed. The dimension index starts at 0. It is an error to\n * squeeze a dimension that is not 1.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n squeeze(axis?: number[]): T {\n this.throwIfDisposed();\n return opHandler.squeeze(this, axis);\n }\n\n /** Returns a copy of the tensor. See `tf.clone` for details. */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n clone(this: T): T {\n this.throwIfDisposed();\n return opHandler.clone(this);\n }\n\n /**\n * Returns a human-readable description of the tensor. Useful for logging.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n toString(verbose = false): string {\n const vals = this.dataSync();\n return tensorToString(vals, this.shape, this.dtype, verbose);\n }\n\n // Below is chain API that is not exposed to docs to avoid repetition. To\n // expose a method, move it above this comment and add @doc and jsdoc.\n\n gather(this: T, indices: Tensor|TensorLike, axis = 0): T {\n this.throwIfDisposed();\n return opHandler.gather(this, indices, axis);\n }\n\n matMul(\n this: T, b: T|TensorLike, transposeA = false, transposeB = false): T {\n this.throwIfDisposed();\n return opHandler.matMul(this, b, transposeA, transposeB);\n }\n dot(b: Tensor|TensorLike): Tensor {\n this.throwIfDisposed();\n return opHandler.dot(this, b);\n }\n norm(\n ord: number|'euclidean'|'fro' = 'euclidean', axis: number|number[] = null,\n keepDims = false): Tensor {\n this.throwIfDisposed();\n return opHandler.norm(this, ord, axis, keepDims);\n }\n slice>(\n this: T, begin: number|number[], size?: number|number[]): T {\n this.throwIfDisposed();\n return opHandler.slice(this, begin, size);\n }\n reverse(this: T, axis?: number|number[]): T {\n this.throwIfDisposed();\n return opHandler.reverse(this, axis);\n }\n concat(this: T, x: T|Array, axis = 0): T {\n this.throwIfDisposed();\n if (x instanceof Tensor) {\n x = [x];\n }\n return opHandler.concat([this, ...x], axis);\n }\n split(this: T, numOrSizeSplits: number[]|number, axis = 0):\n T[] {\n this.throwIfDisposed();\n return opHandler.split(this, numOrSizeSplits, axis);\n }\n stack(x: Tensor, axis = 0): Tensor {\n return opHandler.stack([this, x], axis);\n }\n unstack(axis = 0): Tensor[] {\n return opHandler.unstack(this, axis);\n }\n /**\n * @deprecated Use `tf.batchNorm` instead, and note the positional argument\n * change of scale, offset, and varianceEpsilon.\n */\n batchNormalization(\n mean: Tensor|Tensor1D|TensorLike,\n variance: Tensor|Tensor1D|TensorLike, varianceEpsilon = .001,\n scale?: Tensor|Tensor1D|TensorLike,\n offset?: Tensor|Tensor1D|TensorLike): Tensor {\n deprecationWarningFn(\n 'tf.batchNormalization() is going away. ' +\n 'Use tf.batchNorm() instead, and note the positional argument change ' +\n 'of scale, offset, and varianceEpsilon');\n return this.batchNorm(mean, variance, offset, scale, varianceEpsilon);\n }\n\n // Reduction ops.\n all(axis: number|number[] = null, keepDims = false): T {\n this.throwIfDisposed();\n return opHandler.all(this, axis, keepDims);\n }\n any(axis: number|number[] = null, keepDims = false): T {\n this.throwIfDisposed();\n return opHandler.any(this, axis, keepDims);\n }\n logSumExp(axis: number|number[] = null, keepDims = false):\n T {\n this.throwIfDisposed();\n return opHandler.logSumExp(this, axis, keepDims);\n }\n sum(axis: number|number[] = null, keepDims = false): T {\n this.throwIfDisposed();\n return opHandler.sum(this, axis, keepDims);\n }\n prod(axis: number|number[] = null, keepDims = false): T {\n this.throwIfDisposed();\n return opHandler.prod(this, axis, keepDims);\n }\n mean(axis: number|number[] = null, keepDims = false): T {\n this.throwIfDisposed();\n return opHandler.mean(this, axis, keepDims);\n }\n min(axis: number|number[] = null, keepDims = false): T {\n this.throwIfDisposed();\n return opHandler.min(this, axis, keepDims);\n }\n max(axis: number|number[] = null, keepDims = false): T {\n this.throwIfDisposed();\n return opHandler.max(this, axis, keepDims);\n }\n argMin(axis: number = null): T {\n this.throwIfDisposed();\n return opHandler.argMin(this, axis);\n }\n argMax(axis: number = null): T {\n this.throwIfDisposed();\n return opHandler.argMax(this, axis);\n }\n\n // Transformations\n cast(dtype: DataType): T {\n this.throwIfDisposed();\n return opHandler.cast(this as T, dtype);\n }\n\n // Binary ops.\n addStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.addStrict(this, x);\n }\n atan2(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.atan2(this, x);\n }\n sub(x: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.sub(this, x);\n }\n subStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.subStrict(this, x);\n }\n pow(this: T, exp: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.pow(this, exp);\n }\n powStrict(exp: Tensor|TensorLike): Tensor {\n this.throwIfDisposed();\n return opHandler.powStrict(this, exp);\n }\n mul(x: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.mul(this, x);\n }\n mulStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.mulStrict(this, x);\n }\n floorDiv(x: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.floorDiv(this, x);\n }\n divStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.divStrict(this, x);\n }\n minimum(x: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.minimum(this, x);\n }\n minimumStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.minimumStrict(this, x);\n }\n maximum(x: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.maximum(this, x);\n }\n maximumStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.maximumStrict(this, x);\n }\n mod(x: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.mod(this, x);\n }\n modStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.modStrict(this, x);\n }\n squaredDifferenceStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.squaredDifferenceStrict(this, x);\n }\n\n // Compare ops.\n\n notEqual(x: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.notEqual(this, x);\n }\n notEqualStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.notEqualStrict(this, x);\n }\n less(x: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.less(this, x);\n }\n lessStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.lessStrict(this, x);\n }\n equal(x: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.equal(this, x);\n }\n equalStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.equalStrict(this, x);\n }\n lessEqual(x: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.lessEqual(this, x);\n }\n lessEqualStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.lessEqualStrict(this, x);\n }\n greater(x: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.greater(this, x);\n }\n greaterStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.greaterStrict(this, x);\n }\n greaterEqual(x: Tensor|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.greaterEqual(this, x);\n }\n greaterEqualStrict(this: T, x: T|TensorLike): T {\n this.throwIfDisposed();\n return opHandler.greaterEqualStrict(this, x);\n }\n\n // Compare ops.\n logicalAnd(x: Tensor|TensorLike): Tensor {\n this.throwIfDisposed();\n return opHandler.logicalAnd(this, x);\n }\n logicalOr(x: Tensor|TensorLike): Tensor {\n this.throwIfDisposed();\n return opHandler.logicalOr(this, x);\n }\n logicalNot(this: T): T {\n this.throwIfDisposed();\n return opHandler.logicalNot(this);\n }\n logicalXor(x: Tensor|TensorLike): Tensor {\n this.throwIfDisposed();\n return opHandler.logicalXor(this, x);\n }\n where(condition: Tensor|TensorLike, x: Tensor|TensorLike): Tensor {\n this.throwIfDisposed();\n return opHandler.where(condition, this, x);\n }\n\n // Unary ops.\n neg(this: T): T {\n this.throwIfDisposed();\n return opHandler.neg(this);\n }\n ceil(this: T): T {\n this.throwIfDisposed();\n return opHandler.ceil(this);\n }\n floor(this: T): T {\n this.throwIfDisposed();\n return opHandler.floor(this);\n }\n sign(this: T): T {\n this.throwIfDisposed();\n return opHandler.sign(this);\n }\n isNaN(this: T): T {\n this.throwIfDisposed();\n return opHandler.isNaN(this);\n }\n isInf(this: T): T {\n this.throwIfDisposed();\n return opHandler.isInf(this);\n }\n isFinite(this: T): T {\n this.throwIfDisposed();\n return opHandler.isFinite(this);\n }\n exp(this: T): T {\n this.throwIfDisposed();\n return opHandler.exp(this);\n }\n expm1(this: T): T {\n this.throwIfDisposed();\n return opHandler.expm1(this);\n }\n log(this: T): T {\n this.throwIfDisposed();\n return opHandler.log(this);\n }\n log1p(this: T): T {\n this.throwIfDisposed();\n return opHandler.log1p(this);\n }\n sqrt(this: T): T {\n this.throwIfDisposed();\n return opHandler.sqrt(this);\n }\n rsqrt(this: T): T {\n this.throwIfDisposed();\n return opHandler.rsqrt(this);\n }\n square(this: T): T {\n this.throwIfDisposed();\n return opHandler.square(this);\n }\n reciprocal(this: T): T {\n this.throwIfDisposed();\n return opHandler.reciprocal(this);\n }\n abs(this: T): T {\n this.throwIfDisposed();\n return opHandler.abs(this);\n }\n clipByValue(min: number, max: number): Tensor {\n this.throwIfDisposed();\n return opHandler.clipByValue(this, min, max);\n }\n relu(this: T): T {\n this.throwIfDisposed();\n return opHandler.relu(this);\n }\n relu6(this: T): T {\n this.throwIfDisposed();\n return opHandler.relu6(this);\n }\n elu(this: T): T {\n this.throwIfDisposed();\n return opHandler.elu(this);\n }\n selu(this: T): T {\n this.throwIfDisposed();\n return opHandler.selu(this);\n }\n leakyRelu(alpha = 0.2): Tensor {\n this.throwIfDisposed();\n return opHandler.leakyRelu(this, alpha);\n }\n prelu(alpha: Tensor|TensorLike): Tensor {\n this.throwIfDisposed();\n return opHandler.prelu(this, alpha);\n }\n sigmoid(this: T): T {\n this.throwIfDisposed();\n return opHandler.sigmoid(this);\n }\n logSigmoid(this: T): T {\n this.throwIfDisposed();\n return opHandler.logSigmoid(this);\n }\n softplus(this: T): T {\n this.throwIfDisposed();\n return opHandler.softplus(this);\n }\n zerosLike(this: T): T {\n this.throwIfDisposed();\n return opHandler.zerosLike(this);\n }\n onesLike(this: T): T {\n this.throwIfDisposed();\n return opHandler.onesLike(this);\n }\n sin(this: T): T {\n this.throwIfDisposed();\n return opHandler.sin(this);\n }\n cos(this: T): T {\n this.throwIfDisposed();\n return opHandler.cos(this);\n }\n tan(this: T): T {\n this.throwIfDisposed();\n return opHandler.tan(this);\n }\n asin(this: T): T {\n this.throwIfDisposed();\n return opHandler.asin(this);\n }\n acos(this: T): T {\n this.throwIfDisposed();\n return opHandler.acos(this);\n }\n atan(this: T): T {\n this.throwIfDisposed();\n return opHandler.atan(this);\n }\n sinh(this: T): T {\n this.throwIfDisposed();\n return opHandler.sinh(this);\n }\n cosh(this: T): T {\n this.throwIfDisposed();\n return opHandler.cosh(this);\n }\n tanh(this: T): T {\n this.throwIfDisposed();\n return opHandler.tanh(this);\n }\n asinh(this: T): T {\n this.throwIfDisposed();\n return opHandler.asinh(this);\n }\n acosh(this: T): T {\n this.throwIfDisposed();\n return opHandler.acosh(this);\n }\n atanh(this: T): T {\n this.throwIfDisposed();\n return opHandler.atanh(this);\n }\n erf(this: T): T {\n this.throwIfDisposed();\n return opHandler.erf(this);\n }\n round(this: T): T {\n this.throwIfDisposed();\n return opHandler.round(this);\n }\n step(this: T, alpha = 0.0): T {\n this.throwIfDisposed();\n return opHandler.step(this, alpha);\n }\n softmax(this: T, dim = -1): T {\n this.throwIfDisposed();\n return opHandler.softmax(this, dim);\n }\n logSoftmax(this: T, axis = -1): T {\n this.throwIfDisposed();\n return opHandler.logSoftmax(this, axis);\n }\n\n // Image ops.\n resizeBilinear(\n this: T, newShape2D: [number, number], alignCorners = false): T {\n (this as Tensor).throwIfDisposed();\n return opHandler.image.resizeBilinear(this, newShape2D, alignCorners);\n }\n\n resizeNearestNeighbor(\n this: T, newShape2D: [number, number], alignCorners = false): T {\n (this as Tensor).throwIfDisposed();\n return opHandler.image.resizeNearestNeighbor(\n this, newShape2D, alignCorners);\n }\n\n // Convolutions.\n conv1d(\n this: T, filter: Tensor3D|TensorLike3D, stride: number,\n pad: 'valid'|'same'|number, dataFormat: 'NWC'|'NCW' = 'NWC', dilation = 1,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n (this as Tensor).throwIfDisposed();\n return opHandler.conv1d(\n this, filter, stride, pad, dataFormat, dilation, dimRoundingMode);\n }\n conv2d(\n this: T, filter: Tensor4D|TensorLike4D, strides: [number, number]|number,\n pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW' = 'NHWC',\n dilations: [number, number]|number = [1, 1],\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n (this as Tensor).throwIfDisposed();\n return opHandler.conv2d(\n this, filter, strides, pad, dataFormat, dilations, dimRoundingMode);\n }\n conv2dTranspose(\n this: T, filter: Tensor4D|TensorLike4D,\n outputShape: [number, number, number, number]|[number, number, number],\n strides: [number, number]|number, pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n (this as Tensor).throwIfDisposed();\n return opHandler.conv2dTranspose(\n this, filter, outputShape, strides, pad, dimRoundingMode);\n }\n depthwiseConv2D(\n this: T, filter: Tensor4D|TensorLike4D, strides: [number, number]|number,\n pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW' = 'NHWC',\n dilations: [number, number]|number = [1, 1],\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n (this as Tensor).throwIfDisposed();\n return opHandler.depthwiseConv2d(\n this, filter, strides, pad, dataFormat, dilations, dimRoundingMode);\n }\n\n separableConv2d(\n this: T|TensorLike, depthwiseFilter: Tensor4D|TensorLike4D,\n pointwiseFilter: Tensor4D|TensorLike, strides: [number, number]|number,\n pad: 'valid'|'same', dilation: [number, number]|number = [1, 1],\n dataFormat: 'NHWC'|'NCHW' = 'NHWC'): T {\n (this as Tensor).throwIfDisposed();\n return opHandler.separableConv2d(\n this, depthwiseFilter, pointwiseFilter, strides, pad, dilation,\n dataFormat);\n }\n\n // Pooling.\n avgPool(\n this: T, filterSize: [number, number]|number,\n strides: [number, number]|number, pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n (this as Tensor).throwIfDisposed();\n return opHandler.avgPool(this, filterSize, strides, pad, dimRoundingMode);\n }\n maxPool(\n this: T, filterSize: [number, number]|number,\n strides: [number, number]|number, pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n (this as Tensor).throwIfDisposed();\n return opHandler.maxPool(this, filterSize, strides, pad, dimRoundingMode);\n }\n localResponseNormalization(\n this: T, radius = 5, bias = 1, alpha = 1, beta = 0.5): T {\n return opHandler.localResponseNormalization(\n this, radius, bias, alpha, beta);\n }\n pool(\n this: T, windowShape: [number, number]|number, poolingType: 'max'|'avg',\n padding: 'valid'|'same'|number, dilationRate?: [number, number]|number,\n strides?: [number, number]|number): T {\n (this as Tensor).throwIfDisposed();\n return opHandler.pool(\n this, windowShape, poolingType, padding, dilationRate, strides);\n }\n\n variable(trainable = true, name?: string, dtype?: DataType): Variable {\n this.throwIfDisposed();\n return trackerFn().makeVariable(this, trainable, name, dtype) as\n Variable;\n }\n\n unsortedSegmentSum(\n this: T, segmentIds: Tensor1D|TensorLike1D, numSegments: number): T {\n this.throwIfDisposed();\n return opHandler.unsortedSegmentSum(this, segmentIds, numSegments);\n }\n\n batchToSpaceND(\n this: T, blockShape: number[], crops: number[][]): T {\n this.throwIfDisposed();\n return opHandler.batchToSpaceND(this, blockShape, crops);\n }\n\n spaceToBatchND(\n this: T, blockShape: number[], paddings: number[][]): T {\n this.throwIfDisposed();\n return opHandler.spaceToBatchND(this, blockShape, paddings);\n }\n\n topk(this: T, k = 1, sorted = true):\n {values: T, indices: T} {\n this.throwIfDisposed();\n return opHandler.topk(this, k, sorted);\n }\n\n stridedSlice(\n this: Tensor, begin: number[], end: number[], strides: number[],\n beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0,\n shrinkAxisMask = 0): Tensor {\n this.throwIfDisposed();\n return opHandler.stridedSlice(\n this, begin, end, strides, beginMask, endMask, ellipsisMask,\n newAxisMask, shrinkAxisMask);\n }\n\n depthToSpace(this: Tensor4D, blockSize: number, dataFormat: 'NHWC'|'NCHW'):\n Tensor4D {\n this.throwIfDisposed();\n return opHandler.depthToSpace(this, blockSize, dataFormat);\n }\n\n fft(this: Tensor): Tensor {\n this.throwIfDisposed();\n return opHandler.spectral.fft(this);\n }\n\n ifft(this: Tensor): Tensor {\n this.throwIfDisposed();\n return opHandler.spectral.ifft(this);\n }\n\n rfft(this: Tensor): Tensor {\n this.throwIfDisposed();\n return opHandler.spectral.rfft(this);\n }\n\n irfft(this: Tensor): Tensor {\n this.throwIfDisposed();\n return opHandler.spectral.irfft(this);\n }\n}\nObject.defineProperty(Tensor, Symbol.hasInstance, {\n value: (instance: Tensor) => {\n return !!instance && instance.dataId != null && instance.shape != null &&\n instance.dtype != null;\n }\n});\n\nexport interface NumericTensor extends Tensor {\n dtype: NumericDataType;\n dataSync(): DataTypeMap[D];\n data(): Promise;\n}\n\nexport interface StringTensor extends Tensor {\n dtype: 'string';\n dataSync(): DataTypeMap[D];\n data(): Promise;\n}\n\n/** @doclink Tensor */\nexport type Scalar = Tensor;\n/** @doclink Tensor */\nexport type Tensor1D = Tensor;\n/** @doclink Tensor */\nexport type Tensor2D = Tensor;\n/** @doclink Tensor */\nexport type Tensor3D = Tensor;\n/** @doclink Tensor */\nexport type Tensor4D = Tensor;\n/** @doclink Tensor */\nexport type Tensor5D = Tensor;\n/** @doclink Tensor */\nexport type Tensor6D = Tensor;\n\n/**\n * A mutable `tf.Tensor`, useful for persisting state, e.g. for training.\n */\n/** @doc {heading: 'Tensors', subheading: 'Classes'} */\nexport class Variable extends Tensor {\n name: string;\n\n constructor(\n initialValue: Tensor, public trainable: boolean, name: string,\n tensorId: number) {\n super(\n initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId);\n this.name = name;\n }\n\n /**\n * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have\n * the same shape and dtype as the old `tf.Tensor`.\n *\n * @param newValue New tensor to be assigned to this variable.\n */\n /** @doc {heading: 'Tensors', subheading: 'Classes'} */\n assign(newValue: Tensor): void {\n if (newValue.dtype !== this.dtype) {\n throw new Error(\n `dtype of the new value (${newValue.dtype}) and ` +\n `previous value (${this.dtype}) must match`);\n }\n if (!util.arraysEqual(newValue.shape, this.shape)) {\n throw new Error(\n `shape of the new value (${newValue.shape}) and ` +\n `previous value (${this.shape}) must match`);\n }\n trackerFn().disposeTensor(this);\n this.dataId = newValue.dataId;\n trackerFn().incRef(this, null /* backend */);\n }\n\n dispose(): void {\n trackerFn().disposeVariable(this);\n this.isDisposedInternal = true;\n }\n}\n\nObject.defineProperty(Variable, Symbol.hasInstance, {\n value: (instance: Variable) => {\n return instance instanceof Tensor && instance.assign != null &&\n instance.assign instanceof Function;\n }\n});\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/** @docalias number[] */\nexport interface ShapeMap {\n R0: number[];\n R1: [number];\n R2: [number, number];\n R3: [number, number, number];\n R4: [number, number, number, number];\n R5: [number, number, number, number, number];\n R6: [number, number, number, number, number, number];\n}\n\n/** @docalias number[] */\nexport interface ArrayMap {\n R0: number;\n R1: number[];\n R2: number[][];\n R3: number[][][];\n R4: number[][][][];\n R5: number[][][][][];\n R6: number[][][][][][];\n}\n\nexport interface DataTypeMap {\n float32: Float32Array;\n int32: Int32Array;\n bool: Uint8Array;\n complex64: Float32Array;\n string: string[];\n}\n\nexport interface SingleValueMap {\n bool: boolean;\n int32: number;\n float32: number;\n complex64: number;\n string: string;\n}\n\n/** @docalias 'float32'|'int32'|'bool'|'complex64'|'string' */\nexport type DataType = keyof DataTypeMap;\nexport type NumericDataType = 'float32'|'int32'|'bool'|'complex64';\nexport type TypedArray = Float32Array|Int32Array|Uint8Array;\n/** Tensor data used in tensor creation and user-facing API. */\nexport type DataValues = DataTypeMap[DataType];\n/** The underlying tensor data that gets stored in a backend. */\nexport type BackendValues = Float32Array|Int32Array|Uint8Array|Uint8Array[];\n\nexport enum Rank {\n R0 = 'R0',\n R1 = 'R1',\n R2 = 'R2',\n R3 = 'R3',\n R4 = 'R4',\n R5 = 'R5',\n R6 = 'R6'\n}\n\nexport type FlatVector = boolean[]|number[]|TypedArray;\nexport type RegularArray =\n T[]|T[][]|T[][][]|T[][][][]|T[][][][][]|T[][][][][][];\n\n// tslint:disable-next-line:no-any\nexport interface RecursiveArray {\n [index: number]: T|RecursiveArray;\n}\n\n// Looks for upcasting types. Used, for example, in operations with mixed dtype\n// inputs.\nenum UpcastInt32AndMap {\n 'float32' = 'float32',\n 'int32' = 'int32',\n 'bool' = 'int32',\n 'complex64' = 'complex64'\n}\n\nenum UpcastBoolAndMap {\n 'float32' = 'float32',\n 'int32' = 'int32',\n 'bool' = 'bool',\n 'complex64' = 'complex64'\n}\n\nenum UpcastFloat32AndMap {\n 'float32' = 'float32',\n 'int32' = 'float32',\n 'bool' = 'float32',\n 'complex64' = 'complex64'\n}\n\nenum UpcastComplex64AndMap {\n 'float32' = 'complex64',\n 'int32' = 'complex64',\n 'bool' = 'complex64',\n 'complex64' = 'complex64'\n}\n\nconst upcastTypeMap = {\n 'float32': UpcastFloat32AndMap,\n 'int32': UpcastInt32AndMap,\n 'bool': UpcastBoolAndMap,\n 'complex64': UpcastComplex64AndMap\n};\n\nexport function upcastType(typeA: DataType, typeB: DataType): DataType {\n if (typeA === 'string' || typeB === 'string') {\n if (typeA === 'string' && typeB === 'string') {\n return 'string';\n }\n throw new Error(`Can not upcast ${typeA} with ${typeB}`);\n }\n return upcastTypeMap[typeA][typeB];\n}\n\n/** Returns the output type after summation. */\nexport function sumOutType(type: DataType): DataType {\n return upcastType(type, 'int32');\n}\n\n/** @docalias TypedArray|Array */\nexport type TensorLike =\n TypedArray|number|boolean|string|RecursiveArray|\n RecursiveArray|RecursiveArray|Uint8Array[];\nexport type ScalarLike = number|boolean|string|Uint8Array;\n/** @docalias TypedArray|Array */\nexport type TensorLike1D = TypedArray|number[]|boolean[]|string[]|Uint8Array[];\n/** @docalias TypedArray|Array */\nexport type TensorLike2D = TypedArray|number[]|number[][]|boolean[]|boolean[][]|\n string[]|string[][]|Uint8Array[]|Uint8Array[][];\n/** @docalias TypedArray|Array */\nexport type TensorLike3D = TypedArray|number[]|number[][][]|boolean[]|\n boolean[][][]|string[]|string[][][]|Uint8Array[]|Uint8Array[][][];\n/** @docalias TypedArray|Array */\nexport type TensorLike4D = TypedArray|number[]|number[][][][]|boolean[]|\n boolean[][][][]|string[]|string[][][][]|Uint8Array[]|Uint8Array[][][][];\n/** @docalias TypedArray|Array */\nexport type TensorLike5D =\n TypedArray|number[]|number[][][][][]|boolean[]|boolean[][][][][]|string[]|\n string[][][][][]|Uint8Array[]|Uint8Array[][][][][];\n/** @docalias TypedArray|Array */\nexport type TensorLike6D =\n TypedArray|number[]|number[][][][][][]|boolean[]|boolean[][][][][][]|\n string[]|string[][][][][][]|Uint8Array[]|Uint8Array[][][][][];\n\n/** Type for representing image dat in Uint8Array type. */\nexport interface PixelData {\n width: number;\n height: number;\n data: Uint8Array;\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from './tensor';\nimport {TensorContainer, TensorContainerArray} from './tensor_types';\nimport {upcastType} from './types';\nimport {assert} from './util';\n\nexport function makeTypesMatch(a: T, b: T): [T, T] {\n if (a.dtype === b.dtype) {\n return [a, b];\n }\n const dtype = upcastType(a.dtype, b.dtype);\n return [a.cast(dtype), b.cast(dtype)];\n}\n\nexport function assertTypesMatch(a: Tensor, b: Tensor): void {\n assert(\n a.dtype === b.dtype,\n () => `The dtypes of the first(${a.dtype}) and` +\n ` second(${b.dtype}) input must match`);\n}\n\nexport function isTensorInList(tensor: Tensor, tensorList: Tensor[]): boolean {\n return tensorList.some(x => x.id === tensor.id);\n}\n\n/**\n * Extracts any `Tensor`s found within the provided object.\n *\n * @param container an object that may be a `Tensor` or may directly contain\n * `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it\n * is safe to pass any object here, except that `Promise`s are not\n * supported.\n * @returns An array of `Tensors` found within the passed object. If the\n * argument is simply a `Tensor', a list containing that `Tensor` is\n * returned. If the object is not a `Tensor` or does not\n * contain `Tensors`, an empty list is returned.\n */\nexport function getTensorsInContainer(result: TensorContainer): Tensor[] {\n const list: Tensor[] = [];\n const seen = new Set<{}|void>();\n walkTensorContainer(result, list, seen);\n return list;\n}\n\nfunction walkTensorContainer(\n container: TensorContainer, list: Tensor[], seen: Set<{}|void>): void {\n if (container == null) {\n return;\n }\n if (container instanceof Tensor) {\n list.push(container);\n return;\n }\n if (!isIterable(container)) {\n return;\n }\n // Iteration over keys works also for arrays.\n const iterable = container as TensorContainerArray;\n for (const k in iterable) {\n const val = iterable[k];\n if (!seen.has(val)) {\n seen.add(val);\n walkTensorContainer(val, list, seen);\n }\n }\n}\n\n// tslint:disable-next-line:no-any\nfunction isIterable(obj: any): boolean {\n return Array.isArray(obj) || typeof obj === 'object';\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {BackendTimingInfo, DataMover, KernelBackend} from './backends/backend';\nimport {Environment, setEnvironmentGlobal} from './environment';\nimport {getGradient, getKernel, getKernelsForBackend, GradFunc, NamedAttrMap, TensorInfo} from './kernel_registry';\nimport {Profiler} from './profiler';\nimport {backpropagateGradients, getFilteredNodesXToY, TapeNode} from './tape';\nimport {DataId, setTensorTracker, Tensor, TensorTracker, Variable} from './tensor';\nimport {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';\nimport {getTensorsInContainer} from './tensor_util';\nimport {BackendValues, DataType, DataValues} from './types';\nimport * as util from './util';\nimport {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util';\n\n/**\n * A function that computes an output. The save function is for saving tensors\n * computed in the forward pass, that we need in the backward pass.\n */\nexport type ForwardFunc = (backend: KernelBackend, save?: GradSaveFunc) => T;\n\n/**\n * @docalias (a: Tensor, b: Tensor,..., save?: Function) => {\n * value: Tensor,\n * gradFunc: (dy: Tensor, saved?: NamedTensorMap) => Tensor | Tensor[]\n * }\n */\nexport type CustomGradientFunc =\n (...inputs: Array) => {\n value: T;\n gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[];\n };\n\nexport type MemoryInfo = {\n numTensors: number; numDataBuffers: number; numBytes: number;\n unreliable?: boolean; reasons: string[];\n};\n\ntype KernelProfile = {\n name: string; bytesAdded: number; totalBytesSnapshot: number;\n tensorsAdded: number;\n totalTensorsSnapshot: number;\n inputShapes: number[][];\n outputShapes: number[][];\n};\n\nexport type ProfileInfo = {\n newBytes: number; newTensors: number; peakBytes: number;\n kernels: KernelProfile[];\n result: TensorContainer;\n};\n\nexport interface TimingInfo extends BackendTimingInfo {\n wallMs: number;\n}\n\n/** @docalias Function */\nexport type ScopeFn = () => T;\n\ninterface ScopeState {\n track: Tensor[];\n name: string;\n id: number;\n}\n\nclass EngineState {\n // Public since optimizers will use it.\n registeredVariables: NamedVariableMap = {};\n\n nextTapeNodeId = 0;\n numBytes = 0;\n numTensors = 0;\n numStringTensors = 0;\n numDataBuffers = 0;\n\n activeTape: TapeNode[];\n // Number of nested tf.grad() statements when computing higher-order\n // gradients. E.g. `1` for first-order gradients and `2` for second-order\n // gradients. Used to track if the tape should be removed after a backprop.\n gradientDepth = 0;\n // Number of nested kernel calls. When kernel depth is greater than 1, we turn\n // off the tape.\n kernelDepth = 0;\n\n // Keep Tensors that parallel the tapes.\n activeScope: ScopeState;\n scopeStack: ScopeState[] = [];\n /**\n * Keeps track of the number of data moves during a kernel execution. We\n * maintain a stack since kernels can call other kernels, recursively.\n */\n numDataMovesStack: number[] = [];\n nextScopeId = 0;\n\n tensorInfo = new WeakMap();\n\n profiling = false;\n activeProfile: ProfileInfo =\n {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null};\n\n dispose() {\n for (const variableName in this.registeredVariables) {\n this.registeredVariables[variableName].dispose();\n }\n }\n}\n\nexport class Engine implements TensorTracker, DataMover {\n state: EngineState;\n backendName: string;\n registry: {[id: string]: KernelBackend} = {};\n registryFactory: {\n [id: string]: {\n factory: () => KernelBackend | Promise,\n priority: number\n }\n } = {};\n\n private profiler: Profiler;\n private backendInstance: KernelBackend;\n private pendingBackendInit: Promise;\n private pendingBackendInitId = 0;\n\n constructor(public ENV: Environment) {\n this.state = new EngineState();\n }\n\n async ready(): Promise {\n if (this.pendingBackendInit != null) {\n return this.pendingBackendInit.then(() => {});\n }\n if (this.backendInstance != null) {\n return;\n }\n const sortedBackends = this.getSortedBackends();\n\n for (let i = 0; i < sortedBackends.length; i++) {\n const backendName = sortedBackends[i];\n const success = await this.initializeBackend(backendName).success;\n if (success) {\n await this.setBackend(backendName);\n return;\n }\n }\n\n throw new Error(\n `Could not initialize any backends, all backend initializations ` +\n `failed.`);\n }\n\n get backend(): KernelBackend {\n if (this.pendingBackendInit != null) {\n throw new Error(\n `Backend '${this.backendName}' has not yet been initialized. Make ` +\n `sure to await tf.ready() or await tf.setBackend() before calling ` +\n `other methods`);\n }\n if (this.backendInstance == null) {\n const {name, asyncInit} = this.initializeBackendsAndReturnBest();\n if (asyncInit) {\n throw new Error(\n `The highest priority backend '${name}' has not yet been ` +\n `initialized. Make sure to await tf.ready() or ` +\n `await tf.setBackend() before calling other methods`);\n }\n this.setBackend(name);\n }\n return this.backendInstance;\n }\n\n backendNames(): string[] {\n return Object.keys(this.registryFactory);\n }\n\n findBackend(backendName: string): KernelBackend {\n if (!(backendName in this.registry)) {\n // If the backend hasn't been initialized but we have a registry entry for\n // it, initialize it and return it.\n if (backendName in this.registryFactory) {\n const {asyncInit} = this.initializeBackend(backendName);\n if (asyncInit) {\n // Backend is not ready yet.\n return null;\n }\n } else {\n return null;\n }\n }\n return this.registry[backendName];\n }\n\n findBackendFactory(backendName: string):\n () => KernelBackend | Promise {\n if (!(backendName in this.registryFactory)) {\n return null;\n }\n return this.registryFactory[backendName].factory;\n }\n\n registerBackend(\n backendName: string,\n factory: () => KernelBackend | Promise,\n priority = 1): boolean {\n if (backendName in this.registryFactory) {\n console.warn(\n `${backendName} backend was already registered. ` +\n `Reusing existing backend factory.`);\n return false;\n }\n this.registryFactory[backendName] = {factory, priority};\n return true;\n }\n\n async setBackend(backendName: string): Promise {\n if (this.registryFactory[backendName] == null) {\n throw new Error(`Backend name '${backendName}' not found in registry`);\n }\n this.backendName = backendName;\n if (this.registry[backendName] == null) {\n this.backendInstance = null;\n const {success, asyncInit} = this.initializeBackend(backendName);\n const result = asyncInit ? await success : success;\n if (!result) {\n return false;\n }\n }\n this.backendInstance = this.registry[backendName];\n this.setupRegisteredKernels();\n // Reset the profiler.\n this.profiler = new Profiler(this.backendInstance);\n\n return true;\n }\n\n private setupRegisteredKernels(): void {\n const kernels = getKernelsForBackend(this.backendName);\n kernels.forEach(kernel => {\n if (kernel.setupFunc != null) {\n kernel.setupFunc(this.backendInstance);\n }\n });\n }\n\n private disposeRegisteredKernels(backendName: string): void {\n const kernels = getKernelsForBackend(backendName);\n kernels.forEach(kernel => {\n if (kernel.disposeFunc != null) {\n kernel.disposeFunc(this.registry[backendName]);\n }\n });\n }\n\n /**\n * Initializes a backend by looking up the backend name in the factory\n * registry and calling the factory method. Returns a boolean representing\n * whether the initialization of the backend suceeded. Throws an error if\n * there is no backend in the factory registry.\n */\n private initializeBackend(backendName: string):\n {success: boolean|Promise, asyncInit: boolean} {\n const registryFactoryEntry = this.registryFactory[backendName];\n if (registryFactoryEntry == null) {\n throw new Error(\n `Cannot initialize backend ${backendName}, no registration found.`);\n }\n\n try {\n const backend = registryFactoryEntry.factory();\n // Test if the factory returns a promise.\n if (Promise.resolve(backend) === backend) {\n const promiseId = ++this.pendingBackendInitId;\n const success =\n backend\n .then(backendInstance => {\n // Outdated promise. Another backend was set in the meantime.\n if (promiseId < this.pendingBackendInitId) {\n return false;\n }\n this.registry[backendName] = backendInstance;\n this.pendingBackendInit = null;\n return true;\n })\n .catch(err => {\n // Outdated promise. Another backend was set in the meantime.\n if (promiseId < this.pendingBackendInitId) {\n return false;\n }\n this.pendingBackendInit = null;\n console.warn(\n `Initialization of backend ${backendName} failed`);\n console.warn(err.stack || err.message);\n return false;\n });\n this.pendingBackendInit = success;\n return {success, asyncInit: true};\n } else {\n this.registry[backendName] = backend as KernelBackend;\n return {success: true, asyncInit: false};\n }\n } catch (err) {\n console.warn(`Initialization of backend ${backendName} failed`);\n console.warn(err.stack || err.message);\n return {success: false, asyncInit: false};\n }\n }\n\n removeBackend(backendName: string): void {\n if (!(backendName in this.registryFactory)) {\n throw new Error(`${backendName} backend not found in registry`);\n }\n if (this.backendName === backendName && this.pendingBackendInit != null) {\n // There is a pending promise of the backend we want to remove. Make it\n // obsolete.\n this.pendingBackendInitId++;\n }\n\n if (backendName in this.registry) {\n this.disposeRegisteredKernels(backendName);\n this.registry[backendName].dispose();\n delete this.registry[backendName];\n }\n\n delete this.registryFactory[backendName];\n\n // Unset the backend if it is active.\n if (this.backendName === backendName) {\n this.pendingBackendInit = null;\n this.backendName = null;\n this.backendInstance = null;\n }\n }\n\n private getSortedBackends(): string[] {\n if (Object.keys(this.registryFactory).length === 0) {\n throw new Error('No backend found in registry.');\n }\n return Object.keys(this.registryFactory).sort((a: string, b: string) => {\n // Highest priority comes first.\n return this.registryFactory[b].priority -\n this.registryFactory[a].priority;\n });\n }\n\n private initializeBackendsAndReturnBest():\n {name: string, asyncInit: boolean} {\n const sortedBackends = this.getSortedBackends();\n\n for (let i = 0; i < sortedBackends.length; i++) {\n const backendName = sortedBackends[i];\n const {success, asyncInit} = this.initializeBackend(backendName);\n if (asyncInit || success) {\n return {name: backendName, asyncInit};\n }\n }\n throw new Error(\n `Could not initialize any backends, all backend initializations ` +\n `failed.`);\n }\n\n moveData(destBackend: KernelBackend, dataId: DataId) {\n const info = this.state.tensorInfo.get(dataId);\n const srcBackend = info.backend;\n const values = this.readSync(dataId);\n // Delete the tensor from the old backend and move it to the new\n // backend.\n srcBackend.disposeData(dataId);\n info.backend = destBackend;\n destBackend.move(dataId, values, info.shape, info.dtype);\n if (this.shouldCheckForMemLeaks()) {\n // Track the number of moves during a kernel execution to correctly\n // detect memory leaks.\n this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;\n }\n }\n\n tidy(nameOrFn: string|ScopeFn, fn?: ScopeFn):\n T {\n let name: string = null;\n if (fn == null) {\n // Called with only 1 argument.\n if (typeof nameOrFn !== 'function') {\n throw new Error('Please provide a function to tidy()');\n }\n fn = nameOrFn;\n } else {\n // Called with 2 arguments.\n if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {\n throw new Error(\n 'When calling with two arguments, the first argument ' +\n 'to tidy() must be a string');\n }\n if (typeof fn !== 'function') {\n throw new Error(\n 'When calling with two arguments, the 2nd argument ' +\n 'to tidy() must be a function');\n }\n name = nameOrFn as string;\n // TODO(nsthorat,smilkov): Do operation logging and performance\n // profiling.\n }\n let result: T;\n return this.scopedRun(\n () => this.startScope(name), () => this.endScope(result), () => {\n result = fn();\n if (result instanceof Promise) {\n console.error('Cannot return a Promise inside of tidy.');\n }\n return result;\n });\n }\n\n private scopedRun(start: () => void, end: () => void, f: () => T): T {\n start();\n try {\n const res = f();\n end();\n return res;\n } catch (ex) {\n end();\n throw ex;\n }\n }\n\n private static nextTensorId = 0;\n private nextTensorId(): number {\n return Engine.nextTensorId++;\n }\n\n private static nextVariableId = 0;\n private nextVariableId(): number {\n return Engine.nextVariableId++;\n }\n\n /**\n * This method is called instead of the public-facing tensor.clone() when\n * saving a tensor for backwards pass. It makes sure to add the clone\n * operation to the tape regardless of being called inside a kernel\n * execution.\n *\n * This method will go away once all kernels are modularized since we won't\n * need to turn off the tape inside runKernel().\n */\n private clone(x: Tensor): Tensor {\n const y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype);\n const inputs = {x};\n const grad = (dy: Tensor) => ({x: () => dy.toFloat()});\n const saved: Tensor[] = [];\n this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});\n return y;\n }\n\n /**\n * Execute a kernel with the given name and return the output tensor.\n *\n * @param kernelName The name of the kernel to execute.\n * @param inputs A map of input names to tensors.\n * @param attrs A map of attribute names to their values. An attribute is a\n * primitive (non-tensor) input to the kernel.\n * @param inputsToSave A list of tensors, inputs to save for the backprop\n * computation.\n * @param outputsToSave A list of booleans, specifying which output to save\n * for the backprop computation. These are booleans since the output\n * tensors are not visible to the user.\n */\n runKernel(\n kernelName: string, inputs: NamedTensorMap, attrs: NamedAttrMap,\n inputsToSave?: Tensor[], outputsToSave?: boolean[]): Tensor|Tensor[] {\n const forwardFunc: null = null;\n const backwardsFunc: null = null;\n // Call runKernel as a stop-gap until we modularize all kernels.\n // Once we modularize all kernels, we will remove the existing\n // `runKernelFunc`.\n return this.runKernelFunc(\n forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave,\n outputsToSave);\n }\n\n private shouldCheckForMemLeaks(): boolean {\n return this.ENV.getBool('IS_TEST');\n }\n\n private checkKernelForMemLeak(\n kernelName: string, numDataIdsBefore: number,\n outInfos: TensorInfo[]): void {\n const numDataIdsAfter = this.backend.numDataIds();\n\n // Count the number of data ids associated with the result of the kernel.\n let numOutputDataIds = 0;\n outInfos.forEach(info => {\n // Complex numbers allocate 3 data ids, one for 'real', one for\n // 'imaginary', and one for the container that holds the former two.\n numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);\n });\n\n // Account for the number of moves during kernel execution. A \"data move\"\n // can happen in the middle of a kernel execution, placing a new (key,value)\n // pair in the data storage. Since data moves have net zero effect (we\n // always remove the data from the old backend), we have to cancel them out\n // when detecting memory leaks.\n const numMoves =\n this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];\n const dataIdsLeaked =\n numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;\n if (dataIdsLeaked > 0) {\n throw new Error(\n `Backend '${this.backendName}' has an internal memory leak ` +\n `(${dataIdsLeaked} data ids) after running '${kernelName}'`);\n }\n }\n\n /**\n * @deprecated Use `runKernel` for newly added kernels. Keep using this method\n * only for kernels that are not yet fully modularized.\n */\n runKernelFunc(\n forwardFunc: ForwardFunc, inputs: I,\n backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]},\n kernelName?: string, attrs?: NamedAttrMap, inputsToSave?: Tensor[],\n outputsToSave?: boolean[]): T {\n let outputs: Tensor[];\n let saved: Tensor[] = [];\n const isTapeOn = this.isTapeOn();\n if (kernelName == null) {\n kernelName =\n this.state.activeScope != null ? this.state.activeScope.name : '';\n }\n\n const startingBytecount = this.state.numBytes;\n const startingNumTensors = this.state.numTensors;\n\n if (this.shouldCheckForMemLeaks()) {\n this.state.numDataMovesStack.push(0);\n }\n\n let kernelFunc: () => Tensor[];\n const kernel = getKernel(kernelName, this.backendName);\n let out: TensorInfo|TensorInfo[];\n if (kernel != null) {\n kernelFunc = () => {\n const numDataIdsBefore = this.backend.numDataIds();\n out = kernel.kernelFunc({inputs, attrs, backend: this.backend});\n const outInfos = Array.isArray(out) ? out : [out];\n if (this.shouldCheckForMemLeaks()) {\n this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);\n }\n const outTensors = outInfos.map(\n ({dataId, shape, dtype}) =>\n this.makeTensorFromDataId(dataId, shape, dtype));\n\n // Save the inputs and outputs.\n // Do not save unless we are recording to the tape. Otherwise it would\n // cause a mem leak since we would never run backprop, which disposes\n // the kept tensors.\n if (isTapeOn) {\n let tensorsToSave =\n this.getTensorsForGradient(kernelName, inputs, outTensors);\n if (tensorsToSave == null) {\n // Fallback for ops that call runKernelFunc and pass in\n // inputsToSave and outputsToSave. Currently this is the set of ops\n // with kernel support in the WASM backend. Once those ops and\n // respective gradients are modularised we can remove this path.\n if (outputsToSave == null) {\n outputsToSave = [];\n }\n const outsToSave = outTensors.filter((_, i) => outputsToSave[i]);\n tensorsToSave = (inputsToSave || []).slice().concat(outsToSave);\n }\n saved = this.saveTensorsForBackwardMode(tensorsToSave);\n }\n return outTensors;\n };\n } else {\n const saveFunc: GradSaveFunc = (tensors) => {\n // Do not save unless we are recording to the tape. Otherwise it would\n // cause a mem leak since we would never run backprop, which disposes\n // the kept tensors.\n if (!isTapeOn) {\n return;\n }\n saved = tensors.map(tensor => this.keep(this.clone(tensor)));\n };\n\n kernelFunc = () => {\n const numDataIdsBefore = this.backend.numDataIds();\n out = this.tidy(() => forwardFunc(this.backend, saveFunc));\n const outs = (Array.isArray(out) ? out : [out]) as Tensor[];\n if (this.shouldCheckForMemLeaks()) {\n this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs);\n }\n return outs;\n };\n }\n\n // Stop recording to a tape when running a kernel.\n this.scopedRun(\n () => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {\n if (!this.ENV.getBool('DEBUG')) {\n outputs = kernelFunc();\n } else {\n outputs = this.profiler.profileKernel(\n kernelName, inputs, () => kernelFunc());\n }\n });\n\n if (isTapeOn) {\n this.addTapeNode(\n kernelName, inputs, outputs, backwardsFunc, saved, attrs);\n }\n\n if (this.state.profiling) {\n this.state.activeProfile.kernels.push({\n name: kernelName,\n bytesAdded: this.state.numBytes - startingBytecount,\n totalBytesSnapshot: this.state.numBytes,\n tensorsAdded: this.state.numTensors - startingNumTensors,\n totalTensorsSnapshot: this.state.numTensors,\n inputShapes: Object.keys(inputs).map(key => inputs[key].shape),\n outputShapes: outputs.map(item => item.shape)\n });\n }\n return (Array.isArray(out) ? outputs : outputs[0]) as T;\n }\n\n /**\n * Saves tensors used in forward mode for use in backward mode.\n *\n * @param tensors the list of tensors to save.\n */\n private saveTensorsForBackwardMode(tensors: Tensor[]): Tensor[] {\n const saved = tensors.map(tensor => this.keep(this.clone(tensor)));\n return saved;\n }\n\n /**\n * Returns a list of tensors to save for a given gradient calculation.\n *\n * Returns undefined if their is no registered gradient for this kernel in the\n * gradient registry.\n *\n * @param kernelName name of kernel to look up gradient for.\n * @param inputs a map of input tensors.\n * @param outputs an array of output tensors from forward mode of kernel.\n */\n private getTensorsForGradient(\n kernelName: string, inputs: NamedTensorMap,\n outputs: Tensor[]): Tensor[]|null {\n const gradConfig = getGradient(kernelName);\n if (gradConfig != null) {\n const inputsToSave: string[] = gradConfig.inputsToSave || [];\n const outputsToSave: boolean[] = gradConfig.outputsToSave || [];\n\n // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs\n // specified in inputsToSave will be saved.\n let inputTensorsToSave: Tensor[];\n if (gradConfig.saveAllInputs) {\n util.assert(\n Array.isArray(inputs),\n () => 'saveAllInputs is true, expected inputs to be an array.');\n\n inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]);\n } else {\n inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]);\n }\n\n const outputTensorsToSave: Tensor[] =\n outputs.filter((_, i) => outputsToSave[i]);\n\n return inputTensorsToSave.concat(outputTensorsToSave);\n }\n // TODO(yassogba) throw exception here once all runkernelFunc calls with\n // inputsToSave/outputsToSave are removed\n return null;\n }\n\n /**\n * Internal method used by public APIs for tensor creation. Makes a new\n * tensor with the provided shape, dtype and values. It always\n * creates a new data id and writes the values to the underlying backend.\n */\n makeTensor(\n values: DataValues, shape: number[], dtype: DataType,\n backend?: KernelBackend): Tensor {\n if (values == null) {\n throw new Error('Values passed to engine.makeTensor() are null');\n }\n dtype = dtype || 'float32';\n backend = backend || this.backend;\n let backendVals = values as BackendValues;\n if (dtype === 'string' && util.isString(values[0])) {\n backendVals = (values as string[]).map(d => util.encodeString(d));\n }\n const dataId = backend.write(backendVals, shape, dtype);\n const t = new Tensor(shape, dtype, dataId, this.nextTensorId());\n this.incRef(t, backend);\n\n // Count bytes for string tensors.\n if (dtype === 'string') {\n const info = this.state.tensorInfo.get(dataId);\n const newBytes = bytesFromStringArray(backendVals as Uint8Array[]);\n this.state.numBytes += newBytes - info.bytes;\n info.bytes = newBytes;\n }\n return t;\n }\n\n /**\n * Internal method used by backends. Makes a new tensor\n * that is a wrapper around an existing data id. It doesn't create\n * a new data id, only increments the ref count used in memory tracking.\n */\n makeTensorFromDataId(\n dataId: DataId, shape: number[], dtype: DataType,\n backend?: KernelBackend): Tensor {\n dtype = dtype || 'float32';\n const t = new Tensor(shape, dtype, dataId, this.nextTensorId());\n this.incRef(t, backend);\n return t;\n }\n\n makeVariable(\n initialValue: Tensor, trainable = true, name?: string,\n dtype?: DataType): Variable {\n name = name || this.nextVariableId().toString();\n if (dtype != null && dtype !== initialValue.dtype) {\n initialValue = initialValue.asType(dtype);\n }\n const v = new Variable(initialValue, trainable, name, this.nextTensorId());\n if (this.state.registeredVariables[v.name] != null) {\n throw new Error(`Variable with name ${v.name} was already registered`);\n }\n this.state.registeredVariables[v.name] = v;\n this.incRef(v, this.backend);\n return v;\n }\n\n incRef(a: Tensor, backend: KernelBackend): void {\n const refCount = this.state.tensorInfo.has(a.dataId) ?\n this.state.tensorInfo.get(a.dataId).refCount :\n 0;\n this.state.numTensors++;\n if (a.dtype === 'string') {\n this.state.numStringTensors++;\n }\n if (refCount === 0) {\n this.state.numDataBuffers++;\n\n // Bytes for complex numbers are counted by their components. Bytes for\n // string tensors are counted when writing values.\n let bytes = 0;\n if (a.dtype !== 'complex64' && a.dtype !== 'string') {\n bytes = a.size * util.bytesPerElement(a.dtype);\n }\n this.state.tensorInfo.set(a.dataId, {\n backend: backend || this.backend,\n dtype: a.dtype,\n shape: a.shape,\n bytes,\n refCount: 0\n });\n this.state.numBytes += bytes;\n }\n this.state.tensorInfo.get(a.dataId).refCount++;\n if (!(a instanceof Variable)) {\n this.track(a);\n }\n }\n\n disposeTensor(a: Tensor): void {\n if (!this.state.tensorInfo.has(a.dataId)) {\n return;\n }\n\n this.state.numTensors--;\n if (a.dtype === 'string') {\n this.state.numStringTensors--;\n }\n const info = this.state.tensorInfo.get(a.dataId);\n const refCount = info.refCount;\n if (refCount <= 1) {\n // Don't count bytes for complex numbers as they are counted by their\n // components.\n if (a.dtype !== 'complex64') {\n this.state.numBytes -= info.bytes;\n }\n this.state.numDataBuffers--;\n info.backend.disposeData(a.dataId);\n this.state.tensorInfo.delete(a.dataId);\n } else {\n this.state.tensorInfo.get(a.dataId).refCount--;\n }\n // TODO(nsthorat): Construct an error and save the stack trace for\n // debugging when in debug mode. Creating a stack trace is too expensive\n // to do unconditionally.\n }\n\n disposeVariables(): void {\n for (const varName in this.state.registeredVariables) {\n const v = this.state.registeredVariables[varName];\n this.disposeVariable(v);\n }\n }\n\n disposeVariable(v: Variable): void {\n this.disposeTensor(v);\n if (this.state.registeredVariables[v.name] != null) {\n delete this.state.registeredVariables[v.name];\n }\n }\n\n memory(): MemoryInfo {\n const info = this.backend.memory() as MemoryInfo;\n info.numTensors = this.state.numTensors;\n info.numDataBuffers = this.state.numDataBuffers;\n info.numBytes = this.state.numBytes;\n if (this.state.numStringTensors > 0) {\n info.unreliable = true;\n if (info.reasons == null) {\n info.reasons = [];\n }\n info.reasons.push(\n 'Memory usage by string tensors is approximate ' +\n '(2 bytes per character)');\n }\n return info;\n }\n\n async profile(query: () => TensorContainer): Promise {\n this.state.profiling = true;\n\n const startBytes = this.state.numBytes;\n const startNumTensors = this.state.numTensors;\n\n this.state.activeProfile.kernels = [];\n this.state.activeProfile.result = query();\n\n this.state.profiling = false;\n\n this.state.activeProfile.peakBytes = Math.max(\n ...this.state.activeProfile.kernels.map(d => d.totalBytesSnapshot));\n this.state.activeProfile.newBytes = this.state.numBytes - startBytes;\n this.state.activeProfile.newTensors =\n this.state.numTensors - startNumTensors;\n return this.state.activeProfile;\n }\n\n isTapeOn(): boolean {\n return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;\n }\n\n private addTapeNode(\n kernelName: string, inputs: NamedTensorMap, outputs: Tensor[],\n gradientsFunc: GradFunc, saved: Tensor[], attrs: NamedAttrMap): void {\n const tapeNode: TapeNode =\n {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved};\n\n const gradConfig = getGradient(kernelName);\n if (gradConfig != null) {\n gradientsFunc = gradConfig.gradFunc;\n }\n if (gradientsFunc != null) {\n tapeNode.gradient = (dys: Tensor[]) => {\n // TODO(smilkov): To optimize back-prop, pass dys that are not used in\n // the backprop graph to the user as null instead of zeros\n dys = dys.map((dy, i) => {\n if (dy == null) {\n const output = outputs[i];\n const vals = util.makeZerosTypedArray(output.size, output.dtype);\n return this.makeTensor(vals, output.shape, output.dtype);\n }\n return dy;\n });\n // Grad functions of ops with single outputs expect a dy, while ops\n // with multiple outputs expect dys (array of dy).\n return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);\n };\n }\n this.state.activeTape.push(tapeNode);\n }\n\n keep(result: T): T {\n result.kept = true;\n return result;\n }\n\n private startTape() {\n if (this.state.gradientDepth === 0) {\n this.state.activeTape = [];\n }\n this.state.gradientDepth++;\n }\n\n private endTape() {\n this.state.gradientDepth--;\n }\n\n /**\n * Start a scope. Use this with endScope() to achieve the same functionality\n * as scope() without the need for a function closure.\n */\n startScope(name?: string) {\n const scopeInfo: ScopeState = {\n track: [],\n name: 'unnamed scope',\n id: this.state.nextScopeId++\n };\n if (name) {\n scopeInfo.name = name;\n }\n this.state.scopeStack.push(scopeInfo);\n this.state.activeScope = scopeInfo;\n }\n\n /**\n * End a scope. Use this with startScope() to achieve the same functionality\n * as scope() without the need for a function closure.\n */\n endScope(result?: TensorContainer) {\n const tensorsToTrackInParent = getTensorsInContainer(result);\n const tensorsToTrackInParentSet =\n new Set(tensorsToTrackInParent.map(t => t.id));\n\n // Dispose the arrays tracked in this scope.\n for (let i = 0; i < this.state.activeScope.track.length; i++) {\n const tensor = this.state.activeScope.track[i];\n if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {\n tensor.dispose();\n }\n }\n\n const oldScope = this.state.scopeStack.pop();\n this.state.activeScope = this.state.scopeStack.length === 0 ?\n null :\n this.state.scopeStack[this.state.scopeStack.length - 1];\n\n // Track the current result in the parent scope.\n tensorsToTrackInParent.forEach(tensor => {\n // Only track the tensor if was allocated in the inner scope and is not\n // globally kept.\n if (!tensor.kept && tensor.scopeId === oldScope.id) {\n this.track(tensor);\n }\n });\n }\n\n /**\n * Returns gradients of `f` with respect to each of the `xs`. The gradients\n * returned are of the same length as `xs`, but some might be null if `f`\n * was not a function of that `x`. It also takes optional dy to multiply the\n * gradient, which defaults to `1`.\n */\n gradients(\n f: () => T, xs: Tensor[], dy?: T,\n allowNoGradients = false): {value: T, grads: Tensor[]} {\n util.assert(\n xs.length > 0, () => 'gradients() received an empty list of xs.');\n if (dy != null && dy.dtype !== 'float32') {\n throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);\n }\n\n const y = this.scopedRun(\n () => this.startTape(), () => this.endTape(),\n () => this.tidy('forward', f));\n\n util.assert(\n y instanceof Tensor,\n () => 'The result y returned by f() must be a tensor.');\n // Filter out the nodes that don't connect x => y.\n const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);\n if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {\n throw new Error(\n 'Cannot compute gradient of y=f(x) with respect to x. Make sure ' +\n 'that the f you passed encloses all operations that lead from x ' +\n 'to y.');\n }\n\n return this.tidy('backward', () => {\n const accumulatedGradientMap: {[tensorId: number]: Tensor} = {};\n accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy;\n\n // Backprop gradients through the filtered nodes.\n backpropagateGradients(\n accumulatedGradientMap, filteredTape,\n // Pass the tidy function to avoid circular dep with `tape.ts`.\n f => this.tidy(f as ScopeFn));\n const grads = xs.map(x => accumulatedGradientMap[x.id]);\n\n if (this.state.gradientDepth === 0) {\n // This means that we are not computing higher-order gradients\n // and can clean up the tape.\n this.state.activeTape.forEach(node => {\n for (const tensor of node.saved) {\n tensor.dispose();\n }\n });\n this.state.activeTape = null;\n }\n return {value: y, grads};\n });\n }\n\n customGrad(f: CustomGradientFunc):\n (...args: Array) => T {\n util.assert(\n util.isFunction(f),\n () => 'The f passed in customGrad(f) must be a function.');\n return (...inputs: Tensor[]): T => {\n util.assert(\n inputs.every(t => t instanceof Tensor),\n () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +\n 'tensors');\n\n let res: {\n value: T,\n gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[],\n };\n const inputMap: NamedTensorMap = {};\n inputs.forEach((input, i) => {\n inputMap[i] = input;\n });\n return this.runKernelFunc(\n (_, save) => {\n res = f(...[...inputs, save]);\n util.assert(\n res.value instanceof Tensor,\n () => 'The function f passed in customGrad(f) must return an ' +\n 'object where `obj.value` is a tensor');\n util.assert(\n util.isFunction(res.gradFunc),\n () => 'The function f passed in customGrad(f) must return an ' +\n 'object where `obj.gradFunc` is a function.');\n return res.value;\n },\n inputMap,\n (dy: T, saved: Tensor[]) => {\n const gradRes = res.gradFunc(dy, saved);\n const grads: Tensor[] =\n Array.isArray(gradRes) ? gradRes : [gradRes];\n util.assert(\n grads.length === inputs.length,\n () => 'The function f passed in customGrad(f) must return an ' +\n 'object where `obj.gradFunc` is a function that returns ' +\n 'the same number of tensors as inputs passed to f(...).');\n util.assert(\n grads.every(t => t instanceof Tensor),\n () => 'The function f passed in customGrad(f) must return an ' +\n 'object where `obj.gradFunc` is a function that returns ' +\n 'a list of only tensors.');\n const gradMap: {[key: string]: () => Tensor} = {};\n grads.forEach((grad, i) => {\n gradMap[i] = () => grad;\n });\n return gradMap;\n });\n };\n }\n\n readSync(dataId: DataId): BackendValues {\n // Route the read to the correct backend.\n const info = this.state.tensorInfo.get(dataId);\n return info.backend.readSync(dataId);\n }\n read(dataId: DataId): Promise {\n // Route the read to the correct backend.\n const info = this.state.tensorInfo.get(dataId);\n return info.backend.read(dataId);\n }\n\n async time(query: () => void): Promise {\n const start = now();\n const timingInfo = await this.backend.time(query) as TimingInfo;\n timingInfo.wallMs = now() - start;\n return timingInfo;\n }\n\n /**\n * Tracks a Tensor in the current scope to be automatically cleaned up\n * when the current scope ends, and returns the value.\n *\n * @param result The Tensor to track in the current scope.\n */\n private track(result: T): T {\n if (this.state.activeScope != null) {\n result.scopeId = this.state.activeScope.id;\n this.state.activeScope.track.push(result);\n }\n\n return result;\n }\n\n get registeredVariables(): NamedVariableMap {\n return this.state.registeredVariables;\n }\n\n /**\n * Resets the engine state. Removes all backends but does not remove\n * registered backend factories.\n */\n reset(): void {\n // Make any pending promise obsolete.\n this.pendingBackendInitId++;\n\n this.state.dispose();\n this.ENV.reset();\n this.state = new EngineState();\n\n for (const backendName in this.registry) {\n this.disposeRegisteredKernels(backendName);\n this.registry[backendName].dispose();\n delete this.registry[backendName];\n }\n this.backendName = null;\n this.backendInstance = null;\n this.pendingBackendInit = null;\n }\n}\n\nfunction ones(shape: number[]): Tensor {\n const values = makeOnesTypedArray(sizeFromShape(shape), 'float32');\n return ENGINE.makeTensor(values, shape, 'float32');\n}\n\nlet GLOBAL: {_tfengine: Engine};\nfunction getGlobalNamespace(): {_tfengine: Engine} {\n if (GLOBAL == null) {\n // tslint:disable-next-line:no-any\n let ns: any;\n if (typeof (window) !== 'undefined') {\n ns = window;\n } else if (typeof (global) !== 'undefined') {\n ns = global;\n } else if (typeof (process) !== 'undefined') {\n ns = process;\n } else if (typeof (self) !== 'undefined') {\n ns = self;\n } else {\n throw new Error('Could not find a global object');\n }\n GLOBAL = ns;\n }\n return GLOBAL;\n}\n\nfunction getOrMakeEngine(): Engine {\n const ns = getGlobalNamespace();\n if (ns._tfengine == null) {\n const environment = new Environment(ns);\n ns._tfengine = new Engine(environment);\n }\n setEnvironmentGlobal(ns._tfengine.ENV);\n\n // Tell the current tensor interface that the global engine is responsible\n // for tracking.\n setTensorTracker(() => ns._tfengine);\n return ns._tfengine;\n}\n\nexport const ENGINE = getOrMakeEngine();\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from './tensor';\nimport {NamedTensorMap} from './tensor_types';\nimport * as util from './util';\n\nexport interface TapeNode {\n id: number;\n kernelName: string;\n outputs: Tensor[];\n inputs: NamedTensorMap;\n // Optional params, defined only for ops with gradient impl.\n gradient?: (dys: Tensor[]) => NamedGradientMap;\n saved?: Tensor[];\n}\n\nexport type NamedGradientMap = {\n [inputName: string]: () => Tensor;\n};\n\n/**\n * Computes a list of TapeNodes that connect x to y, filtering everything else\n * out and preserving the order of the original tape elements.\n *\n * @param tape The tape elements to filter.\n * @param xs The input Tensors.\n * @param y The output Tensor.\n */\nexport function getFilteredNodesXToY(\n tape: TapeNode[], xs: Tensor[], y: Tensor): TapeNode[] {\n // Forward pass to compute all the nodes and Tensors that are transitively a\n // function of x.\n const tensorsFromX: {[tensorId: number]: boolean} = {};\n const nodesFromX: {[nodeId: number]: boolean} = {};\n for (let i = 0; i < xs.length; i++) {\n tensorsFromX[xs[i].id] = true;\n }\n\n for (let i = 0; i < tape.length; i++) {\n const node = tape[i];\n const nodeInputs = node.inputs;\n for (const inputName in nodeInputs) {\n const input = nodeInputs[inputName];\n\n let anyInputFromX = false;\n for (let j = 0; j < xs.length; j++) {\n if (tensorsFromX[input.id]) {\n node.outputs.forEach(output => tensorsFromX[output.id] = true);\n anyInputFromX = true;\n nodesFromX[node.id] = true;\n break;\n }\n }\n\n if (anyInputFromX) {\n break;\n }\n }\n }\n\n // Backward pass to find all of the nodes and Tensors that lead to y.\n const tensorsLeadToY: {[tensorId: number]: boolean} = {};\n tensorsLeadToY[y.id] = true;\n const nodesToY: {[nodeId: number]: boolean} = {};\n\n for (let i = tape.length - 1; i >= 0; i--) {\n const node = tape[i];\n const nodeInputs = node.inputs;\n\n // If any of the outputs lead to y, mark all of the inputs as leading to y.\n for (let j = 0; j < node.outputs.length; j++) {\n if (tensorsLeadToY[node.outputs[j].id]) {\n for (const inputName in nodeInputs) {\n tensorsLeadToY[nodeInputs[inputName].id] = true;\n nodesToY[node.id] = true;\n }\n break;\n }\n }\n }\n\n // Return the paths that come from x and lead to y.\n const filteredTape: TapeNode[] = [];\n for (let i = 0; i < tape.length; i++) {\n const node = tape[i];\n\n if (nodesFromX[node.id] && nodesToY[node.id]) {\n // Prune the inputs from the node that aren't a function of x.\n const prunedInputs: {[inputName: string]: Tensor} = {};\n for (const inputName in node.inputs) {\n const nodeInput = node.inputs[inputName];\n if (tensorsFromX[nodeInput.id]) {\n prunedInputs[inputName] = nodeInput;\n }\n }\n\n // Copy the node and overwrite inputsAndArgs to the pruned version.\n const prunedNode = Object.assign({}, node);\n prunedNode.inputs = prunedInputs;\n prunedNode.outputs = node.outputs;\n\n filteredTape.push(prunedNode);\n }\n }\n\n return filteredTape;\n}\n\n/**\n * Backpropagate gradients through the filtered TapeNodes.\n *\n * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map\n * is mutated by this method.\n * @param filteredTape The filtered TapeNodes to backprop through.\n */\nexport function backpropagateGradients(\n tensorAccumulatedGradientMap: {[tensorId: number]: Tensor},\n filteredTape: TapeNode[], tidy: (f: Function) => Tensor) {\n // Walk the tape backward and keep a map of Tensor to its gradient.\n for (let i = filteredTape.length - 1; i >= 0; i--) {\n const node = filteredTape[i];\n\n const dys: Tensor[] = [];\n node.outputs.forEach(o => {\n const gradTensor = tensorAccumulatedGradientMap[o.id];\n if (gradTensor != null) {\n dys.push(gradTensor);\n } else {\n // This particular output is not in the back-propagation subgraph, so it\n // does not affect the final output, thus we put null for its dy.\n dys.push(null);\n }\n });\n\n if (node.gradient == null) {\n throw new Error(\n `Cannot compute gradient: gradient function not found ` +\n `for ${node.kernelName}.`);\n }\n\n // Backprop dy through this node and accumulate gradients over the inputs.\n const inputGradients = node.gradient(dys);\n\n for (const inputName in node.inputs) {\n if (!(inputName in inputGradients)) {\n throw new Error(\n `Cannot backprop through input ${inputName}. ` +\n `Available gradients found: ${Object.keys(inputGradients)}.`);\n }\n\n // Call the gradient function.\n const dx = tidy(() => inputGradients[inputName]());\n if (dx.dtype !== 'float32') {\n throw new Error(\n `Error in gradient for op ${\n node.kernelName}. The gradient of input ` +\n `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);\n }\n const x = node.inputs[inputName];\n if (!util.arraysEqual(dx.shape, x.shape)) {\n throw new Error(\n `Error in gradient for op ${\n node.kernelName}. The gradient of input ` +\n `'${inputName}' has shape '${dx.shape}', which does not match ` +\n `the shape of the input '${x.shape}'`);\n }\n\n if (tensorAccumulatedGradientMap[x.id] == null) {\n tensorAccumulatedGradientMap[x.id] = dx;\n } else {\n const curGradient = tensorAccumulatedGradientMap[x.id];\n tensorAccumulatedGradientMap[x.id] = curGradient.add(dx);\n curGradient.dispose();\n }\n }\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nexport function isMobile(): boolean {\n // tslint:disable-next-line:no-any\n const a = navigator.userAgent || navigator.vendor || (window as any).opera;\n // tslint:disable-next-line:max-line-length\n return /(android|bb\\d+|meego).+mobile|avantgo|bada\\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i\n .test(a) ||\n // tslint:disable-next-line:max-line-length\n /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\\-(n|u)|c55\\/|capi|ccwa|cdm\\-|cell|chtm|cldc|cmd\\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\\-s|devi|dica|dmob|do(c|p)o|ds(12|\\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\\-|_)|g1 u|g560|gene|gf\\-5|g\\-mo|go(\\.w|od)|gr(ad|un)|haie|hcit|hd\\-(m|p|t)|hei\\-|hi(pt|ta)|hp( i|ip)|hs\\-c|ht(c(\\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\\-(20|go|ma)|i230|iac( |\\-|\\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\\/)|klon|kpt |kwc\\-|kyo(c|k)|le(no|xi)|lg( g|\\/(k|l|u)|50|54|\\-[a-w])|libw|lynx|m1\\-w|m3ga|m50\\/|ma(te|ui|xo)|mc(01|21|ca)|m\\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\\-2|po(ck|rt|se)|prox|psio|pt\\-g|qa\\-a|qc(07|12|21|32|60|\\-[2-7]|i\\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\\-|oo|p\\-)|sdk\\/|se(c(\\-|0|1)|47|mc|nd|ri)|sgh\\-|shar|sie(\\-|m)|sk\\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\\-|v\\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\\-|tdg\\-|tel(i|m)|tim\\-|t\\-mo|to(pl|sh)|ts(70|m\\-|m3|m5)|tx\\-9|up(\\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\\-|your|zeto|zte\\-/i\n .test(a.substr(0, 4));\n}\n\nexport function isBrowser(): boolean {\n return (typeof window !== 'undefined' && window.document != null) ||\n //@ts-ignore\n (typeof WorkerGlobalScope !== 'undefined');\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport * as device_util from './device_util';\nimport {env} from './environment';\n\nconst ENV = env();\n\n/**\n * This file contains environment-related flag registrations.\n */\n\n/** Whether to enable debug mode. */\nENV.registerFlag('DEBUG', () => false, debugValue => {\n if (debugValue) {\n console.warn(\n 'Debugging mode is ON. The output of every math call will ' +\n 'be downloaded to CPU and checked for NaNs. ' +\n 'This significantly impacts performance.');\n }\n});\n\n/** Whether we are in a browser (as versus, say, node.js) environment. */\nENV.registerFlag('IS_BROWSER', () => device_util.isBrowser());\n\n/** Whether we are in a browser (as versus, say, node.js) environment. */\nENV.registerFlag(\n 'IS_NODE',\n () => (typeof process !== 'undefined') &&\n (typeof process.versions !== 'undefined') &&\n (typeof process.versions.node !== 'undefined'));\n\n/** Whether this browser is Chrome. */\nENV.registerFlag(\n 'IS_CHROME',\n () => typeof navigator !== 'undefined' && navigator != null &&\n navigator.userAgent != null && /Chrome/.test(navigator.userAgent) &&\n /Google Inc/.test(navigator.vendor));\n\n/**\n * True when the environment is \"production\" where we disable safety checks\n * to gain performance.\n */\nENV.registerFlag('PROD', () => false);\n\n/**\n * Whether to do sanity checks when inferring a shape from user-provided\n * values, used when creating a new tensor.\n */\nENV.registerFlag(\n 'TENSORLIKE_CHECK_SHAPE_CONSISTENCY', () => ENV.getBool('DEBUG'));\n\n/** Whether deprecation warnings are enabled. */\nENV.registerFlag('DEPRECATION_WARNINGS_ENABLED', () => true);\n\n/** True if running unit tests. */\nENV.registerFlag('IS_TEST', () => false);\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nconst contexts: {[key: string]: WebGLRenderingContext} = {};\n\nconst WEBGL_ATTRIBUTES: WebGLContextAttributes = {\n alpha: false,\n antialias: false,\n premultipliedAlpha: false,\n preserveDrawingBuffer: false,\n depth: false,\n stencil: false,\n failIfMajorPerformanceCaveat: true\n};\n\nexport function setWebGLContext(\n webGLVersion: number, gl: WebGLRenderingContext) {\n contexts[webGLVersion] = gl;\n}\n\nexport function getWebGLContext(webGLVersion: number): WebGLRenderingContext {\n if (!(webGLVersion in contexts)) {\n contexts[webGLVersion] = getWebGLRenderingContext(webGLVersion);\n }\n const gl = contexts[webGLVersion];\n if (gl.isContextLost()) {\n delete contexts[webGLVersion];\n return getWebGLContext(webGLVersion);\n }\n\n gl.disable(gl.DEPTH_TEST);\n gl.disable(gl.STENCIL_TEST);\n gl.disable(gl.BLEND);\n gl.disable(gl.DITHER);\n gl.disable(gl.POLYGON_OFFSET_FILL);\n gl.disable(gl.SAMPLE_COVERAGE);\n gl.enable(gl.SCISSOR_TEST);\n gl.enable(gl.CULL_FACE);\n gl.cullFace(gl.BACK);\n\n return contexts[webGLVersion];\n}\n\nfunction createCanvas(webGLVersion: number) {\n if (typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2) {\n return new OffscreenCanvas(300, 150);\n } else if (typeof document !== 'undefined') {\n return document.createElement('canvas');\n } else {\n throw new Error('Cannot create a canvas in this context');\n }\n}\n\nfunction getWebGLRenderingContext(webGLVersion: number): WebGLRenderingContext {\n if (webGLVersion !== 1 && webGLVersion !== 2) {\n throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');\n }\n const canvas = createCanvas(webGLVersion);\n\n canvas.addEventListener('webglcontextlost', (ev: Event) => {\n ev.preventDefault();\n delete contexts[webGLVersion];\n }, false);\n if (webGLVersion === 1) {\n return (canvas.getContext('webgl', WEBGL_ATTRIBUTES) ||\n canvas.getContext('experimental-webgl', WEBGL_ATTRIBUTES)) as\n WebGLRenderingContext;\n }\n return canvas.getContext('webgl2', WEBGL_ATTRIBUTES) as WebGLRenderingContext;\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../../environment';\n\nimport {DataId, Tensor} from '../../tensor';\nimport {BackendValues, DataType} from '../../types';\nimport * as util from '../../util';\n\nexport enum PackingScheme {\n /**\n * All values in a single texel are densely packed without any constraints.\n *\n * This is how the shader encodes a tensor with shape = [2, 3, 4]\n * (indices are [batch, row, col]).\n *\n * 000|001 010|011 020|021\n * ------- ------- -------\n * 002|003 012|013 022|023\n *\n * 100|101 110|111 120|121\n * ------- ------- -------\n * 102|103 112|113 122|123\n *\n */\n DENSE,\n\n /**\n * Single texels contain only values from the same batch, and from adjacent\n * rows and columns.\n *\n * This is how the shader encodes a tensor with shape = [2, 3, 5]\n * (indices are [batch, row, col]).\n *\n * 000|001 002|003 004|xxx 020|021 022|023 024|xxx\n * ------- ------- ------- ------- ------- -------\n * 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx\n *\n * 100|101 102|103 104|xxx 120|121 122|123 124|xxx\n * ------- ------- ------- ------- ------- -------\n * 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx\n *\n */\n SHARED_BATCH\n}\n\nexport enum TextureUsage {\n RENDER,\n UPLOAD,\n PIXELS,\n DOWNLOAD\n}\n\nexport enum PhysicalTextureType {\n UNPACKED_FLOAT16,\n UNPACKED_FLOAT32,\n PACKED_4X1_UNSIGNED_BYTE,\n PACKED_2X2_FLOAT32,\n PACKED_2X2_FLOAT16\n}\n\nexport interface TextureData {\n // Required.\n shape: number[];\n dtype: DataType;\n\n // Optional.\n values?: BackendValues;\n texture?: WebGLTexture;\n // For complex numbers, the real and imaginary parts are stored as their own\n // individual tensors, with a parent joining the two with the\n // complexTensors field. When this is defined, texture will be null.\n complexTensors?: {real: Tensor, imag: Tensor};\n /** [rows, columns] shape of the texture. */\n texShape?: [number, number];\n usage?: TextureUsage;\n isPacked?: boolean;\n\n // Available when the tensor has been sliced.\n slice?: {\n // Offset in the 'flat index' space.\n flatOffset: number;\n // Used for counting how many sliced tensors point to the same texture.\n origDataId: DataId;\n };\n}\n\nexport function getUnpackedMatrixTextureShapeWidthHeight(\n rows: number, columns: number): [number, number] {\n return [columns, rows];\n}\n\nexport function getUnpackedArraySizeFromMatrixSize(\n matrixSize: number, channelsPerTexture: number): number {\n return matrixSize * channelsPerTexture;\n}\n\nexport function getColorMatrixTextureShapeWidthHeight(\n rows: number, columns: number): [number, number] {\n return [columns * 4, rows];\n}\n\n/**\n * Get shape for densely packed RGBA texture.\n */\nexport function getDenseTexShape(shape: number[]): [number, number] {\n const size = util.sizeFromShape(shape);\n const texelsNeeded = Math.ceil(size / 4);\n return util.sizeToSquarishShape(texelsNeeded);\n}\n\nexport function getMatrixSizeFromUnpackedArraySize(\n unpackedSize: number, channelsPerTexture: number): number {\n if (unpackedSize % channelsPerTexture !== 0) {\n throw new Error(\n `unpackedSize (${unpackedSize}) must be a multiple of ` +\n `${channelsPerTexture}`);\n }\n return unpackedSize / channelsPerTexture;\n}\n\nexport function decodeMatrixFromUnpackedColorRGBAArray(\n unpackedArray: Float32Array, matrix: Float32Array, channels: number) {\n const requiredSize = unpackedArray.length * channels / 4;\n if (matrix.length < requiredSize) {\n throw new Error(\n `matrix length (${matrix.length}) must be >= ${requiredSize}`);\n }\n let dst = 0;\n for (let src = 0; src < unpackedArray.length; src += 4) {\n for (let c = 0; c < channels; c++) {\n matrix[dst++] = unpackedArray[src + c];\n }\n }\n}\n\nexport function getPackedMatrixTextureShapeWidthHeight(\n rows: number, columns: number): [number, number] {\n return [\n Math.max(1, Math.ceil(columns / 2)), Math.max(1, Math.ceil(rows / 2))\n ];\n}\n\nexport function getPackedRGBAArraySizeFromMatrixShape(\n rows: number, columns: number): number {\n const [w, h] = getPackedMatrixTextureShapeWidthHeight(rows, columns);\n return w * h * 4;\n}\n\nexport interface TextureConfig {\n internalFormatFloat: number;\n textureFormatFloat: number;\n internalFormatPackedHalfFloat: number;\n internalFormatHalfFloat: number;\n internalFormatPackedFloat: number;\n\n // The format to use during a gl.readPixels call.\n downloadTextureFormat: number;\n // How many channels need to be unpacked after a gl.readPixels call.\n downloadUnpackNumChannels: number;\n\n defaultNumChannels: number;\n textureTypeHalfFloat: number;\n textureTypeFloat: number;\n}\n\nexport function getTextureConfig(\n // tslint:disable-next-line:no-any\n gl: WebGLRenderingContext, textureHalfFloatExtension?: any): TextureConfig {\n // tslint:disable-next-line:no-any\n const glany = gl as any;\n\n let internalFormatFloat: number;\n let internalFormatHalfFloat: number;\n let internalFormatPackedHalfFloat: number;\n let internalFormatPackedFloat: number;\n let textureFormatFloat: number;\n\n let downloadTextureFormat: number;\n let downloadUnpackNumChannels: number;\n\n let defaultNumChannels: number;\n let textureTypeHalfFloat: number;\n let textureTypeFloat: number;\n\n if (env().getNumber('WEBGL_VERSION') === 2) {\n internalFormatFloat = glany.R32F;\n internalFormatHalfFloat = glany.R16F;\n internalFormatPackedHalfFloat = glany.RGBA16F;\n internalFormatPackedFloat = glany.RGBA32F;\n textureFormatFloat = glany.RED;\n downloadUnpackNumChannels = 4;\n defaultNumChannels = 1;\n textureTypeHalfFloat = glany.HALF_FLOAT;\n textureTypeFloat = glany.FLOAT;\n } else {\n internalFormatFloat = gl.RGBA;\n internalFormatHalfFloat = gl.RGBA;\n internalFormatPackedHalfFloat = gl.RGBA;\n internalFormatPackedFloat = glany.RGBA;\n textureFormatFloat = gl.RGBA;\n downloadUnpackNumChannels = 4;\n defaultNumChannels = 4;\n textureTypeHalfFloat = textureHalfFloatExtension != null ?\n textureHalfFloatExtension.HALF_FLOAT_OES :\n null;\n textureTypeFloat = gl.FLOAT;\n }\n downloadTextureFormat = gl.RGBA;\n\n return {\n internalFormatFloat,\n internalFormatHalfFloat,\n internalFormatPackedHalfFloat,\n internalFormatPackedFloat,\n textureFormatFloat,\n downloadTextureFormat,\n downloadUnpackNumChannels,\n defaultNumChannels,\n textureTypeHalfFloat,\n textureTypeFloat\n };\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../../environment';\n\nimport * as util from '../../util';\n\nimport {getWebGLContext} from './canvas_util';\nimport {getTextureConfig} from './tex_util';\n\nexport function callAndCheck(\n gl: WebGLRenderingContext, debugMode: boolean, func: () => T): T {\n const returnValue = func();\n if (debugMode) {\n checkWebGLError(gl);\n }\n return returnValue;\n}\n\nfunction checkWebGLError(gl: WebGLRenderingContext) {\n const error = gl.getError();\n if (error !== gl.NO_ERROR) {\n throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error));\n }\n}\n\n// https://en.wikipedia.org/wiki/Half-precision_floating-point_format\nconst MIN_FLOAT16 = 5.96e-8;\nconst MAX_FLOAT16 = 65504;\n\nexport function canBeRepresented(num: number): boolean {\n if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 ||\n (MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) {\n return true;\n }\n return false;\n}\n\nexport function getWebGLErrorMessage(\n gl: WebGLRenderingContext, status: number): string {\n switch (status) {\n case gl.NO_ERROR:\n return 'NO_ERROR';\n case gl.INVALID_ENUM:\n return 'INVALID_ENUM';\n case gl.INVALID_VALUE:\n return 'INVALID_VALUE';\n case gl.INVALID_OPERATION:\n return 'INVALID_OPERATION';\n case gl.INVALID_FRAMEBUFFER_OPERATION:\n return 'INVALID_FRAMEBUFFER_OPERATION';\n case gl.OUT_OF_MEMORY:\n return 'OUT_OF_MEMORY';\n case gl.CONTEXT_LOST_WEBGL:\n return 'CONTEXT_LOST_WEBGL';\n default:\n return `Unknown error code ${status}`;\n }\n}\n\nexport function getExtensionOrThrow(\n gl: WebGLRenderingContext, debug: boolean, extensionName: string): {} {\n return throwIfNull<{}>(\n gl, debug, () => gl.getExtension(extensionName),\n 'Extension \"' + extensionName + '\" not supported on this browser.');\n}\n\nexport function createVertexShader(\n gl: WebGLRenderingContext, debug: boolean,\n vertexShaderSource: string): WebGLShader {\n const vertexShader: WebGLShader = throwIfNull(\n gl, debug, () => gl.createShader(gl.VERTEX_SHADER),\n 'Unable to create vertex WebGLShader.');\n callAndCheck(\n gl, debug, () => gl.shaderSource(vertexShader, vertexShaderSource));\n callAndCheck(gl, debug, () => gl.compileShader(vertexShader));\n if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) {\n console.log(gl.getShaderInfoLog(vertexShader));\n throw new Error('Failed to compile vertex shader.');\n }\n return vertexShader;\n}\n\nexport function createFragmentShader(\n gl: WebGLRenderingContext, debug: boolean,\n fragmentShaderSource: string): WebGLShader {\n const fragmentShader: WebGLShader = throwIfNull(\n gl, debug, () => gl.createShader(gl.FRAGMENT_SHADER),\n 'Unable to create fragment WebGLShader.');\n callAndCheck(\n gl, debug, () => gl.shaderSource(fragmentShader, fragmentShaderSource));\n callAndCheck(gl, debug, () => gl.compileShader(fragmentShader));\n if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) {\n logShaderSourceAndInfoLog(\n fragmentShaderSource, gl.getShaderInfoLog(fragmentShader));\n throw new Error('Failed to compile fragment shader.');\n }\n return fragmentShader;\n}\n\nconst lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g;\nfunction logShaderSourceAndInfoLog(\n shaderSource: string, shaderInfoLog: string) {\n const lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog);\n if (lineNumberRegexResult == null) {\n console.log(`Couldn't parse line number in error: ${shaderInfoLog}`);\n console.log(shaderSource);\n return;\n }\n\n const lineNumber = +lineNumberRegexResult[1];\n\n const shaderLines = shaderSource.split('\\n');\n const pad = shaderLines.length.toString().length + 2;\n const linesWithLineNumbers = shaderLines.map(\n (line, lineNumber) =>\n util.rightPad((lineNumber + 1).toString(), pad) + line);\n let maxLineLength = 0;\n for (let i = 0; i < linesWithLineNumbers.length; i++) {\n maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength);\n }\n\n const beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1);\n const errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber);\n const afterErrorLines = linesWithLineNumbers.slice(lineNumber);\n\n console.log(beforeErrorLines.join('\\n'));\n console.log(shaderInfoLog.split('\\n')[0]);\n console.log(\n `%c ${util.rightPad(errorLine[0], maxLineLength)}`,\n 'border:1px solid red; background-color:#e3d2d2; color:#a61717');\n console.log(afterErrorLines.join('\\n'));\n}\n\nexport function createProgram(\n gl: WebGLRenderingContext, debug: boolean): WebGLProgram {\n return throwIfNull(\n gl, debug, () => gl.createProgram(), 'Unable to create WebGLProgram.');\n}\n\nexport function linkProgram(\n gl: WebGLRenderingContext, debug: boolean, program: WebGLProgram) {\n callAndCheck(gl, debug, () => gl.linkProgram(program));\n if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) {\n console.log(gl.getProgramInfoLog(program));\n throw new Error('Failed to link vertex and fragment shaders.');\n }\n}\n\nexport function validateProgram(\n gl: WebGLRenderingContext, debug: boolean, program: WebGLProgram) {\n callAndCheck(gl, debug, () => gl.validateProgram(program));\n if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) {\n console.log(gl.getProgramInfoLog(program));\n throw new Error('Shader program validation failed.');\n }\n}\n\nexport function createStaticVertexBuffer(\n gl: WebGLRenderingContext, debug: boolean,\n data: Float32Array): WebGLBuffer {\n const buffer: WebGLBuffer = throwIfNull(\n gl, debug, () => gl.createBuffer(), 'Unable to create WebGLBuffer');\n callAndCheck(gl, debug, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));\n callAndCheck(\n gl, debug, () => gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW));\n return buffer;\n}\n\nexport function createStaticIndexBuffer(\n gl: WebGLRenderingContext, debug: boolean, data: Uint16Array): WebGLBuffer {\n const buffer: WebGLBuffer = throwIfNull(\n gl, debug, () => gl.createBuffer(), 'Unable to create WebGLBuffer');\n callAndCheck(gl, debug, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer));\n callAndCheck(\n gl, debug,\n () => gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW));\n return buffer;\n}\n\nexport function getNumChannels(): number {\n if (env().getNumber('WEBGL_VERSION') === 2) {\n return 1;\n }\n return 4;\n}\n\nexport function createTexture(\n gl: WebGLRenderingContext, debug: boolean): WebGLTexture {\n return throwIfNull(\n gl, debug, () => gl.createTexture(), 'Unable to create WebGLTexture.');\n}\n\nexport function validateTextureSize(width: number, height: number) {\n const maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');\n if ((width <= 0) || (height <= 0)) {\n const requested = `[${width}x${height}]`;\n throw new Error('Requested texture size ' + requested + ' is invalid.');\n }\n if ((width > maxTextureSize) || (height > maxTextureSize)) {\n const requested = `[${width}x${height}]`;\n const max = `[${maxTextureSize}x${maxTextureSize}]`;\n throw new Error(\n 'Requested texture size ' + requested +\n ' greater than WebGL maximum on this browser / GPU ' + max + '.');\n }\n}\n\nexport function createFramebuffer(\n gl: WebGLRenderingContext, debug: boolean): WebGLFramebuffer {\n return throwIfNull(\n gl, debug, () => gl.createFramebuffer(),\n 'Unable to create WebGLFramebuffer.');\n}\n\nexport function bindVertexBufferToProgramAttribute(\n gl: WebGLRenderingContext, debug: boolean, program: WebGLProgram,\n attribute: string, buffer: WebGLBuffer, arrayEntriesPerItem: number,\n itemStrideInBytes: number, itemOffsetInBytes: number): boolean {\n const loc = gl.getAttribLocation(program, attribute);\n if (loc === -1) {\n // The GPU compiler decided to strip out this attribute because it's unused,\n // thus no need to bind.\n return false;\n }\n callAndCheck(gl, debug, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));\n callAndCheck(\n gl, debug,\n () => gl.vertexAttribPointer(\n loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes,\n itemOffsetInBytes));\n callAndCheck(gl, debug, () => gl.enableVertexAttribArray(loc));\n return true;\n}\n\nexport function bindTextureUnit(\n gl: WebGLRenderingContext, debug: boolean, texture: WebGLTexture,\n textureUnit: number) {\n validateTextureUnit(gl, textureUnit);\n callAndCheck(gl, debug, () => gl.activeTexture(gl.TEXTURE0 + textureUnit));\n callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, texture));\n}\n\nexport function unbindTextureUnit(\n gl: WebGLRenderingContext, debug: boolean, textureUnit: number) {\n validateTextureUnit(gl, textureUnit);\n callAndCheck(gl, debug, () => gl.activeTexture(gl.TEXTURE0 + textureUnit));\n callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, null));\n}\n\nexport function getProgramUniformLocationOrThrow(\n gl: WebGLRenderingContext, debug: boolean, program: WebGLProgram,\n uniformName: string): WebGLUniformLocation {\n return throwIfNull(\n gl, debug, () => gl.getUniformLocation(program, uniformName),\n 'uniform \"' + uniformName + '\" not present in program.');\n}\n\nexport function getProgramUniformLocation(\n gl: WebGLRenderingContext, program: WebGLProgram,\n uniformName: string): WebGLUniformLocation {\n return gl.getUniformLocation(program, uniformName);\n}\n\nexport function bindTextureToProgramUniformSampler(\n gl: WebGLRenderingContext, debug: boolean, program: WebGLProgram,\n texture: WebGLTexture, uniformSamplerLocation: WebGLUniformLocation,\n textureUnit: number) {\n callAndCheck(\n gl, debug, () => bindTextureUnit(gl, debug, texture, textureUnit));\n callAndCheck(\n gl, debug, () => gl.uniform1i(uniformSamplerLocation, textureUnit));\n}\n\nexport function bindCanvasToFramebuffer(\n gl: WebGLRenderingContext, debug: boolean) {\n callAndCheck(gl, debug, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null));\n callAndCheck(\n gl, debug, () => gl.viewport(0, 0, gl.canvas.width, gl.canvas.height));\n callAndCheck(\n gl, debug, () => gl.scissor(0, 0, gl.canvas.width, gl.canvas.height));\n}\n\nexport function bindColorTextureToFramebuffer(\n gl: WebGLRenderingContext, debug: boolean, texture: WebGLTexture,\n framebuffer: WebGLFramebuffer) {\n callAndCheck(\n gl, debug, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer));\n callAndCheck(\n gl, debug,\n () => gl.framebufferTexture2D(\n gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0));\n}\n\nexport function unbindColorTextureFromFramebuffer(\n gl: WebGLRenderingContext, debug: boolean, framebuffer: WebGLFramebuffer) {\n callAndCheck(\n gl, debug, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer));\n callAndCheck(\n gl, debug,\n () => gl.framebufferTexture2D(\n gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0));\n}\n\nexport function validateFramebuffer(gl: WebGLRenderingContext) {\n const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);\n if (status !== gl.FRAMEBUFFER_COMPLETE) {\n throw new Error(\n 'Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status));\n }\n}\n\nexport function getFramebufferErrorMessage(\n gl: WebGLRenderingContext, status: number): string {\n switch (status) {\n case gl.FRAMEBUFFER_INCOMPLETE_ATTACHMENT:\n return 'FRAMEBUFFER_INCOMPLETE_ATTACHMENT';\n case gl.FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT:\n return 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT';\n case gl.FRAMEBUFFER_INCOMPLETE_DIMENSIONS:\n return 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS';\n case gl.FRAMEBUFFER_UNSUPPORTED:\n return 'FRAMEBUFFER_UNSUPPORTED';\n default:\n return `unknown error ${status}`;\n }\n}\n\nfunction throwIfNull(\n gl: WebGLRenderingContext, debug: boolean, returnTOrNull: () => T | null,\n failureMessage: string): T {\n const tOrNull: T|null = callAndCheck(gl, debug, () => returnTOrNull());\n if (tOrNull == null) {\n throw new Error(failureMessage);\n }\n return tOrNull;\n}\n\nfunction validateTextureUnit(gl: WebGLRenderingContext, textureUnit: number) {\n const maxTextureUnit = gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS - 1;\n const glTextureUnit = textureUnit + gl.TEXTURE0;\n if (glTextureUnit < gl.TEXTURE0 || glTextureUnit > maxTextureUnit) {\n const textureUnitRange = `[gl.TEXTURE0, gl.TEXTURE${maxTextureUnit}]`;\n throw new Error(`textureUnit must be in ${textureUnitRange}.`);\n }\n}\n\nexport function getBatchDim(shape: number[], dimsToSkip = 2): number {\n return util.sizeFromShape(shape.slice(0, shape.length - dimsToSkip));\n}\n\nexport function getRowsCols(shape: number[]): [number, number] {\n if (shape.length === 0) {\n throw Error('Cannot get rows and columns of an empty shape array.');\n }\n\n return [\n shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1]\n ];\n}\n\nexport function getShapeAs3D(shape: number[]): [number, number, number] {\n let shapeAs3D: [number, number, number] = [1, 1, 1];\n const isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1);\n if (!isScalar) {\n shapeAs3D =\n [getBatchDim(shape), ...getRowsCols(shape)] as [number, number, number];\n }\n return shapeAs3D;\n}\n\nexport function getTextureShapeFromLogicalShape(\n logShape: number[], isPacked = false): [number, number] {\n let maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');\n if (isPacked) {\n maxTexSize = maxTexSize * 2;\n\n // This logic ensures we accurately count the number of packed texels needed\n // to accommodate the tensor. We can only pack values in the same texel if\n // they are from adjacent pairs of rows/cols within the same batch. So if a\n // tensor has 3 rows, we pretend it has 4 rows in order to account for the\n // fact that the texels containing the third row are half empty.\n logShape = logShape.map(\n (d, i) => i >= logShape.length - 2 ?\n util.nearestLargerEven(logShape[i]) :\n logShape[i]);\n\n // Packed texture height is at least 2 (the channel height of a single\n // texel).\n if (logShape.length === 1) {\n logShape = [2, logShape[0]];\n }\n }\n\n // If logical shape is 2, we don't squeeze, since we want to match physical.\n if (logShape.length !== 2) {\n const squeezeResult = util.squeezeShape(logShape);\n logShape = squeezeResult.newShape;\n }\n\n let size = util.sizeFromShape(logShape);\n if (logShape.length <= 1 && size <= maxTexSize) {\n return [1, size];\n } else if (\n logShape.length === 2 && logShape[0] <= maxTexSize &&\n logShape[1] <= maxTexSize) {\n return logShape as [number, number];\n } else if (\n logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize &&\n logShape[2] <= maxTexSize) {\n return [logShape[0] * logShape[1], logShape[2]];\n } else if (\n logShape.length === 3 && logShape[0] <= maxTexSize &&\n logShape[1] * logShape[2] <= maxTexSize) {\n return [logShape[0], logShape[1] * logShape[2]];\n } else if (\n logShape.length === 4 &&\n logShape[0] * logShape[1] * logShape[2] <= maxTexSize &&\n logShape[3] <= maxTexSize) {\n return [logShape[0] * logShape[1] * logShape[2], logShape[3]];\n } else if (\n logShape.length === 4 && logShape[0] <= maxTexSize &&\n logShape[1] * logShape[2] * logShape[3] <= maxTexSize) {\n return [logShape[0], logShape[1] * logShape[2] * logShape[3]];\n } else {\n if (isPacked) {\n // For packed textures size equals the number of channels required to\n // accommodate the texture data. However in order to squarify such that\n // inner dimensions stay even, we rewrite size to equal the number of\n // texels. Then in the return statement we rehydrate the squarified\n // dimensions to channel units.\n\n const batchDim = getBatchDim(logShape);\n let rows = 2, cols = 2;\n if (logShape.length) {\n [rows, cols] = getRowsCols(logShape);\n }\n size = batchDim * (rows / 2) * (cols / 2);\n return util.sizeToSquarishShape(size).map(d => d * 2) as [number, number];\n }\n return util.sizeToSquarishShape(size);\n }\n}\n\nfunction isEven(n: number): boolean {\n return n % 2 === 0;\n}\n\n/**\n * This determines whether reshaping a packed texture requires rearranging\n * the data within the texture, assuming 2x2 packing.\n */\nexport function isReshapeFree(shape1: number[], shape2: number[]): boolean {\n shape1 = shape1.slice(-2);\n shape2 = shape2.slice(-2);\n\n if (util.arraysEqual(shape1, shape2)) {\n return true;\n }\n\n if (!shape1.length || !shape2.length) { // One of the shapes is a scalar.\n return true;\n }\n\n if (shape1[0] === 0 || shape1[1] === 0 || shape2[0] === 0 ||\n shape2[1] === 0) {\n return true;\n }\n\n if (shape1.length !== shape2.length) { // One of the shapes is a vector.\n const shape1Cols = shape1.slice(-1)[0];\n const shape2Cols = shape2.slice(-1)[0];\n if (shape1Cols === shape2Cols) {\n return true;\n }\n\n if (isEven(shape1Cols) && isEven(shape2Cols) &&\n (shape1[0] === 1 || shape2[0] === 1)) {\n return true;\n }\n }\n return shape1[1] === shape2[1] && isEven(shape1[0]) && isEven(shape2[0]);\n}\n\n// We cache webgl params because the environment gets reset between\n// unit tests and we don't want to constantly query the WebGLContext for\n// MAX_TEXTURE_SIZE.\nlet MAX_TEXTURE_SIZE: number;\nlet MAX_TEXTURES_IN_SHADER: number;\n\nexport function getWebGLMaxTextureSize(webGLVersion: number): number {\n if (MAX_TEXTURE_SIZE == null) {\n const gl = getWebGLContext(webGLVersion);\n MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE);\n }\n return MAX_TEXTURE_SIZE;\n}\n\nexport function resetMaxTextureSize() {\n MAX_TEXTURE_SIZE = null;\n}\nexport function resetMaxTexturesInShader() {\n MAX_TEXTURES_IN_SHADER = null;\n}\n\nexport function getMaxTexturesInShader(webGLVersion: number): number {\n if (MAX_TEXTURES_IN_SHADER == null) {\n const gl = getWebGLContext(webGLVersion);\n MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);\n }\n // We cap at 16 to avoid spurious runtime \"memory exhausted\" error.\n return Math.min(16, MAX_TEXTURES_IN_SHADER);\n}\n\nexport function getWebGLDisjointQueryTimerVersion(webGLVersion: number):\n number {\n if (webGLVersion === 0) {\n return 0;\n }\n\n let queryTimerVersion: number;\n const gl = getWebGLContext(webGLVersion);\n\n if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') &&\n webGLVersion === 2) {\n queryTimerVersion = 2;\n } else if (hasExtension(gl, 'EXT_disjoint_timer_query')) {\n queryTimerVersion = 1;\n } else {\n queryTimerVersion = 0;\n }\n return queryTimerVersion;\n}\n\nexport function hasExtension(gl: WebGLRenderingContext, extensionName: string) {\n const ext = gl.getExtension(extensionName);\n return ext != null;\n}\n\nexport function isWebGLVersionEnabled(webGLVersion: 1|2) {\n try {\n const gl = getWebGLContext(webGLVersion);\n if (gl != null) {\n return true;\n }\n } catch (e) {\n return false;\n }\n return false;\n}\n\nexport function isCapableOfRenderingToFloatTexture(webGLVersion: number):\n boolean {\n if (webGLVersion === 0) {\n return false;\n }\n\n const gl = getWebGLContext(webGLVersion);\n\n if (webGLVersion === 1) {\n if (!hasExtension(gl, 'OES_texture_float')) {\n return false;\n }\n } else {\n if (!hasExtension(gl, 'EXT_color_buffer_float')) {\n return false;\n }\n }\n\n const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);\n return isFrameBufferComplete;\n}\n\n/**\n * Check if we can download values from a float/half-float texture.\n *\n * Note that for performance reasons we use binding a texture to a framebuffer\n * as a proxy for ability to download float values later using readPixels. The\n * texture params of this texture will not match those in readPixels exactly\n * but if we are unable to bind some kind of float texture to the frameBuffer\n * then we definitely will not be able to read float values from it.\n */\nexport function isDownloadFloatTextureEnabled(webGLVersion: number): boolean {\n if (webGLVersion === 0) {\n return false;\n }\n\n const gl = getWebGLContext(webGLVersion);\n\n if (webGLVersion === 1) {\n if (!hasExtension(gl, 'OES_texture_float')) {\n return false;\n }\n if (!hasExtension(gl, 'WEBGL_color_buffer_float')) {\n return false;\n }\n } else {\n if (hasExtension(gl, 'EXT_color_buffer_float')) {\n return createFloatTextureAndBindToFramebuffer(gl);\n }\n\n const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';\n if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) {\n const textureHalfFloatExtension =\n gl.getExtension(COLOR_BUFFER_HALF_FLOAT);\n return createHalfFloatTextureAndBindToFramebuffer(\n gl, textureHalfFloatExtension);\n }\n\n return false;\n }\n\n const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);\n return isFrameBufferComplete;\n}\n\nfunction createFloatTextureAndBindToFramebuffer(gl: WebGLRenderingContext):\n boolean {\n const texConfig = getTextureConfig(gl);\n\n const texture = gl.createTexture();\n gl.bindTexture(gl.TEXTURE_2D, texture);\n\n const width = 1;\n const height = 1;\n gl.texImage2D(\n gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, height, 0,\n texConfig.textureFormatFloat, texConfig.textureTypeFloat, null);\n\n const frameBuffer = gl.createFramebuffer();\n gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);\n gl.framebufferTexture2D(\n gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);\n\n const isFrameBufferComplete =\n gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;\n\n gl.bindTexture(gl.TEXTURE_2D, null);\n gl.bindFramebuffer(gl.FRAMEBUFFER, null);\n gl.deleteTexture(texture);\n gl.deleteFramebuffer(frameBuffer);\n\n return isFrameBufferComplete;\n}\n\nfunction createHalfFloatTextureAndBindToFramebuffer(\n // tslint:disable-next-line:no-any\n gl: WebGLRenderingContext, textureHalfFloatExtension: any): boolean {\n const texConfig = getTextureConfig(gl, textureHalfFloatExtension);\n const texture = gl.createTexture();\n gl.bindTexture(gl.TEXTURE_2D, texture);\n\n const width = 1;\n const height = 1;\n gl.texImage2D(\n gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0,\n texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null);\n\n const frameBuffer = gl.createFramebuffer();\n gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);\n gl.framebufferTexture2D(\n gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);\n\n const isFrameBufferComplete =\n gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;\n\n gl.bindTexture(gl.TEXTURE_2D, null);\n gl.bindFramebuffer(gl.FRAMEBUFFER, null);\n gl.deleteTexture(texture);\n gl.deleteFramebuffer(frameBuffer);\n\n return isFrameBufferComplete;\n}\n\nexport function isWebGLFenceEnabled(webGLVersion: number) {\n if (webGLVersion !== 2) {\n return false;\n }\n const gl = getWebGLContext(webGLVersion);\n\n // tslint:disable-next-line:no-any\n const isEnabled = (gl as any).fenceSync != null;\n return isEnabled;\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as device_util from '../../device_util';\nimport {env} from '../../environment';\n\nimport * as webgl_util from './webgl_util';\n\nconst ENV = env();\n\n/**\n * This file contains WebGL-specific flag registrations.\n */\n\n/**\n * True if WebGL is supported.\n */\nENV.registerFlag('HAS_WEBGL', () => ENV.getNumber('WEBGL_VERSION') > 0);\n\n/** 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0. */\nENV.registerFlag('WEBGL_VERSION', () => {\n if (webgl_util.isWebGLVersionEnabled(2)) {\n return 2;\n } else if (webgl_util.isWebGLVersionEnabled(1)) {\n return 1;\n }\n return 0;\n});\n\nENV.registerFlag(\n 'WEBGL_BUFFER_SUPPORTED', () => ENV.get('WEBGL_VERSION') === 2);\n\n/** Whether the WebGL backend will sometimes forward ops to the CPU. */\nENV.registerFlag('WEBGL_CPU_FORWARD', () => true);\n\n/** Whether the WebGL backend will always use f16 textures for rendering. */\nENV.registerFlag('WEBGL_FORCE_F16_TEXTURES', () => false);\n\n/** Whether to turn all packing related flags on. */\nENV.registerFlag('WEBGL_PACK', () => ENV.getBool('HAS_WEBGL'));\n\n/** Whether we will pack the batchnormalization op. */\nENV.registerFlag('WEBGL_PACK_NORMALIZATION', () => ENV.getBool('WEBGL_PACK'));\n\n/** Whether we will pack the clip op. */\nENV.registerFlag('WEBGL_PACK_CLIP', () => ENV.getBool('WEBGL_PACK'));\n\n/** Whether we will pack the depthwise conv op. */\n// TODO: https://github.com/tensorflow/tfjs/issues/1679\nENV.registerFlag('WEBGL_PACK_DEPTHWISECONV', () => false);\n\n/** Whether we will pack binary ops. */\nENV.registerFlag(\n 'WEBGL_PACK_BINARY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));\n\n/** Whether we will pack unary ops. */\nENV.registerFlag(\n 'WEBGL_PACK_UNARY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));\n\n/** Whether we will pack array ops. */\nENV.registerFlag(\n 'WEBGL_PACK_ARRAY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));\n\n/** Whether we will pack image ops. */\nENV.registerFlag(\n 'WEBGL_PACK_IMAGE_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));\n\n/** Whether we will pack reduce ops. */\nENV.registerFlag('WEBGL_PACK_REDUCE', () => ENV.getBool('WEBGL_PACK'));\n\n/** Whether packed WebGL kernels lazily unpack their outputs. */\nENV.registerFlag('WEBGL_LAZILY_UNPACK', () => ENV.getBool('WEBGL_PACK'));\n\n/** Whether we will use the im2col algorithm to speed up convolutions. */\nENV.registerFlag('WEBGL_CONV_IM2COL', () => ENV.getBool('WEBGL_PACK'));\n\n/** The maximum texture dimension. */\nENV.registerFlag(\n 'WEBGL_MAX_TEXTURE_SIZE',\n () => webgl_util.getWebGLMaxTextureSize(ENV.getNumber('WEBGL_VERSION')));\n\n/** The maximum texture dimension. */\nENV.registerFlag(\n 'WEBGL_MAX_TEXTURES_IN_SHADER',\n () => webgl_util.getMaxTexturesInShader(ENV.getNumber('WEBGL_VERSION')));\n\n/**\n * The disjoint_query_timer extension version.\n * 0: disabled, 1: EXT_disjoint_timer_query, 2:\n * EXT_disjoint_timer_query_webgl2.\n * In Firefox with WebGL 2.0,\n * EXT_disjoint_timer_query_webgl2 is not available, so we must use the\n * WebGL 1.0 extension.\n */\nENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', () => {\n const webGLVersion = ENV.getNumber('WEBGL_VERSION');\n\n if (webGLVersion === 0) {\n return 0;\n }\n return webgl_util.getWebGLDisjointQueryTimerVersion(webGLVersion);\n});\n\n/**\n * Whether the timer object from the disjoint_query_timer extension gives\n * timing information that is reliable.\n */\nENV.registerFlag(\n 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE',\n () => ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 &&\n !device_util.isMobile());\n\n/**\n * Whether the device is physically capable of rendering to float32 textures.\n */\nENV.registerFlag(\n 'WEBGL_RENDER_FLOAT32_CAPABLE',\n () => webgl_util.isCapableOfRenderingToFloatTexture(\n ENV.getNumber('WEBGL_VERSION')));\n\n/**\n * Whether rendering to float32 textures is enabled. If disabled, renders to\n * float16 textures.\n */\nENV.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', () => {\n return ENV.getBool('WEBGL_FORCE_F16_TEXTURES') ?\n false :\n ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE');\n});\n\n/**\n * Whether downloading float textures is enabled (16 or 32 bit). If disabled,\n * uses IEEE 754 encoding of the float32 values to 4 uint8 when downloading.\n */\nENV.registerFlag(\n 'WEBGL_DOWNLOAD_FLOAT_ENABLED',\n () => webgl_util.isDownloadFloatTextureEnabled(\n ENV.getNumber('WEBGL_VERSION')));\n\n/** Whether the fence API is available. */\nENV.registerFlag(\n 'WEBGL_FENCE_API_ENABLED',\n () => webgl_util.isWebGLFenceEnabled(ENV.getNumber('WEBGL_VERSION')));\n\n/**\n * Tensors with size <= than this will be uploaded as uniforms, not textures.\n */\nENV.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', () => {\n // Use uniform uploads only when 32bit floats are supported. In\n // 16bit\n // environments there are problems with comparing a 16bit texture value\n // with a 32bit uniform value.\n const useUniforms = ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED');\n return useUniforms ? 4 : 0;\n});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {KernelBackend} from './backends/backend';\nimport {ENGINE, Engine, MemoryInfo, ProfileInfo, ScopeFn, TimingInfo} from './engine';\nimport {env} from './environment';\n\nimport {Platform} from './platforms/platform';\nimport {setDeprecationWarningFn, Tensor} from './tensor';\nimport {TensorContainer} from './tensor_types';\nimport {getTensorsInContainer} from './tensor_util';\n\n/**\n * Enables production mode which disables correctness checks in favor of\n * performance.\n */\n/** @doc {heading: 'Environment'} */\nexport function enableProdMode(): void {\n env().set('PROD', true);\n}\n\n/**\n * Enables debug mode which will log information about all executed kernels:\n * the elapsed time of the kernel execution, as well as the rank, shape, and\n * size of the output tensor.\n *\n * Debug mode will significantly slow down your application as it will\n * download the result of every operation to the CPU. This should not be used in\n * production. Debug mode does not affect the timing information of the kernel\n * execution as we do not measure download time in the kernel execution time.\n *\n * See also: `tf.profile`, `tf.memory`.\n */\n/** @doc {heading: 'Environment'} */\nexport function enableDebugMode(): void {\n env().set('DEBUG', true);\n}\n\n/** Globally disables deprecation warnings */\nexport function disableDeprecationWarnings(): void {\n env().set('DEPRECATION_WARNINGS_ENABLED', false);\n console.warn(`TensorFlow.js deprecation warnings have been disabled.`);\n}\n\n/** Warn users about deprecated functionality. */\nexport function deprecationWarn(msg: string) {\n if (env().getBool('DEPRECATION_WARNINGS_ENABLED')) {\n console.warn(\n msg + ' You can disable deprecation warnings with ' +\n 'tf.disableDeprecationWarnings().');\n }\n}\nsetDeprecationWarningFn(deprecationWarn);\n\n/**\n * Dispose all variables kept in backend engine.\n */\n/** @doc {heading: 'Environment'} */\nexport function disposeVariables(): void {\n ENGINE.disposeVariables();\n}\n\n/**\n * It returns the global engine that keeps track of all tensors and backends.\n */\n/** @doc {heading: 'Environment'} */\nexport function engine(): Engine {\n return ENGINE;\n}\n\n/**\n * Returns memory info at the current time in the program. The result is an\n * object with the following properties:\n *\n * - `numBytes`: Number of bytes allocated (undisposed) at this time.\n * - `numTensors`: Number of unique tensors allocated.\n * - `numDataBuffers`: Number of unique data buffers allocated\n * (undisposed) at this time, which is ≤ the number of tensors\n * (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same\n * data buffer with `a`).\n * - `unreliable`: True if the memory usage is unreliable. See `reasons` when\n * `unreliable` is true.\n * - `reasons`: `string[]`, reasons why the memory is unreliable, present if\n * `unreliable` is true.\n *\n * WebGL Properties:\n * - `numBytesInGPU`: Number of bytes allocated (undisposed) in the GPU only at\n * this time.\n */\n/** @doc {heading: 'Performance', subheading: 'Memory'} */\nexport function memory(): MemoryInfo {\n return ENGINE.memory();\n}\n\n/**\n * Executes the provided function `f()` and returns a promise that resolves\n * with information about the function's memory use:\n * - `newBytes`: the number of new bytes allocated\n * - `newTensors`: the number of new tensors created\n * - `peakBytes`: the peak number of bytes allocated\n * - `kernels`: an array of objects for each kernel involved that reports\n * their input and output shapes, number of bytes used, and number of new\n * tensors created.\n *\n * ```js\n * const profile = await tf.profile(() => {\n * const x = tf.tensor1d([1, 2, 3]);\n * let x2 = x.square();\n * x2.dispose();\n * x2 = x.square();\n * x2.dispose();\n * return x;\n * });\n *\n * console.log(`newBytes: ${profile.newBytes}`);\n * console.log(`newTensors: ${profile.newTensors}`);\n * console.log(`byte usage over all kernels: ${profile.kernels.map(k =>\n * k.totalBytesSnapshot)}`);\n * ```\n *\n */\n/** @doc {heading: 'Performance', subheading: 'Profile'} */\nexport function profile(f: () => TensorContainer): Promise {\n return ENGINE.profile(f);\n}\n\n/**\n * Executes the provided function `fn` and after it is executed, cleans up all\n * intermediate tensors allocated by `fn` except those returned by `fn`.\n * `fn` must not return a Promise (async functions not allowed). The returned\n * result can be a complex object.\n *\n * Using this method helps avoid memory leaks. In general, wrap calls to\n * operations in `tf.tidy` for automatic memory cleanup.\n *\n * NOTE: Variables do *not* get cleaned up when inside a tidy(). If you want to\n * dispose variables, please use `tf.disposeVariables` or call dispose()\n * directly on variables.\n *\n * ```js\n * // y = 2 ^ 2 + 1\n * const y = tf.tidy(() => {\n * // a, b, and one will be cleaned up when the tidy ends.\n * const one = tf.scalar(1);\n * const a = tf.scalar(2);\n * const b = a.square();\n *\n * console.log('numTensors (in tidy): ' + tf.memory().numTensors);\n *\n * // The value returned inside the tidy function will return\n * // through the tidy, in this case to the variable y.\n * return b.add(one);\n * });\n *\n * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);\n * y.print();\n * ```\n *\n * @param nameOrFn The name of the closure, or the function to execute.\n * If a name is provided, the 2nd argument should be the function.\n * If debug mode is on, the timing and the memory usage of the function\n * will be tracked and displayed on the console using the provided name.\n * @param fn The function to execute.\n */\n/** @doc {heading: 'Performance', subheading: 'Memory'} */\nexport function tidy(\n nameOrFn: string|ScopeFn, fn?: ScopeFn): T {\n return ENGINE.tidy(nameOrFn, fn);\n}\n\n/**\n * Disposes any `tf.Tensor`s found within the provided object.\n *\n * @param container an object that may be a `tf.Tensor` or may directly\n * contain `tf.Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If\n * the object is not a `tf.Tensor` or does not contain `Tensors`, nothing\n * happens. In general it is safe to pass any object here, except that\n * `Promise`s are not supported.\n */\n/** @doc {heading: 'Performance', subheading: 'Memory'} */\nexport function dispose(container: TensorContainer) {\n const tensors = getTensorsInContainer(container);\n tensors.forEach(tensor => tensor.dispose());\n}\n\n/**\n * Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed\n * automatically.\n *\n * ```js\n * let b;\n * const y = tf.tidy(() => {\n * const one = tf.scalar(1);\n * const a = tf.scalar(2);\n *\n * // b will not be cleaned up by the tidy. a and one will be cleaned up\n * // when the tidy ends.\n * b = tf.keep(a.square());\n *\n * console.log('numTensors (in tidy): ' + tf.memory().numTensors);\n *\n * // The value returned inside the tidy function will return\n * // through the tidy, in this case to the variable y.\n * return b.add(one);\n * });\n *\n * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);\n * console.log('y:');\n * y.print();\n * console.log('b:');\n * b.print();\n * ```\n *\n * @param result The tensor to keep from being disposed.\n */\n/** @doc {heading: 'Performance', subheading: 'Memory'} */\nexport function keep(result: T): T {\n return ENGINE.keep(result);\n}\n\n/**\n * Executes `f()` and returns a promise that resolves with timing\n * information.\n *\n * The result is an object with the following properties:\n *\n * - `wallMs`: Wall execution time.\n * - `kernelMs`: Kernel execution time, ignoring data transfer. If using the\n * WebGL backend and the query timer extension is not available, this will\n * return an error object.\n * - On `WebGL` The following additional properties exist:\n * - `uploadWaitMs`: CPU blocking time on texture uploads.\n * - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels).\n *\n * ```js\n * const x = tf.randomNormal([20, 20]);\n * const time = await tf.time(() => x.matMul(x));\n *\n * console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`);\n * ```\n *\n * @param f The function to execute and time.\n */\n/** @doc {heading: 'Performance', subheading: 'Timing'} */\nexport function time(f: () => void): Promise {\n return ENGINE.time(f);\n}\n\n/**\n * Sets the backend (cpu, webgl, wasm, etc) responsible for creating tensors and\n * executing operations on those tensors. Returns a promise that resolves\n * to a boolean if the backend initialization was successful.\n *\n * Note this disposes the current backend, if any, as well as any tensors\n * associated with it. A new backend is initialized, even if it is of the\n * same type as the previous one.\n *\n * @param backendName The name of the backend. Currently supports\n * `'webgl'|'cpu'` in the browser, `'tensorflow'` under node.js\n * (requires tfjs-node), and `'wasm'` (requires tfjs-backend-wasm).\n */\n/** @doc {heading: 'Backends'} */\nexport function setBackend(backendName: string): Promise {\n return ENGINE.setBackend(backendName);\n}\n\n/**\n * Returns a promise that resolves when the currently selected backend (or the\n * highest priority one) has initialized. Await this promise when you are using\n * a backend that has async initialization.\n */\n/** @doc {heading: 'Backends'} */\nexport function ready(): Promise {\n return ENGINE.ready();\n}\n\n/**\n * Returns the current backend name (cpu, webgl, etc). The backend is\n * responsible for creating tensors and executing operations on those tensors.\n */\n/** @doc {heading: 'Backends'} */\nexport function getBackend(): string {\n return ENGINE.backendName;\n}\n\n/**\n * Removes a backend and the registered factory.\n */\n/** @doc {heading: 'Backends'} */\nexport function removeBackend(name: string): void {\n ENGINE.removeBackend(name);\n}\n\n/**\n * Finds the backend registered under the provided name. Returns null if the\n * name is not in the registry, or the registration hasn't finished yet.\n */\nexport function findBackend(name: string): KernelBackend {\n return ENGINE.findBackend(name);\n}\n\n/**\n * Finds the backend factory registered under the provided name. Returns a\n * function that produces a new backend when called. Returns null if the name\n * is not in the registry.\n */\nexport function findBackendFactory(name: string): () =>\n KernelBackend | Promise {\n return ENGINE.findBackendFactory(name);\n}\n\n/**\n * Registers a global backend. The registration should happen when importing\n * a module file (e.g. when importing `backend_webgl.ts`), and is used for\n * modular builds (e.g. custom tfjs bundle with only webgl support).\n *\n * @param factory The backend factory function. When called, it should\n * return a backend instance, or a promise of an instance.\n * @param priority The priority of the backend (higher = more important).\n * In case multiple backends are registered, the priority is used to find\n * the best backend. Defaults to 1.\n * @return False if there is already a registered backend under this name, true\n * if not.\n */\n/** @doc {heading: 'Backends'} */\nexport function registerBackend(\n name: string, factory: () => KernelBackend | Promise,\n priority = 1): boolean {\n return ENGINE.registerBackend(name, factory, priority);\n}\n\n/**\n * Gets the current backend. If no backends have been initialized, this will\n * attempt to initialize the best backend. Will throw an error if the highest\n * priority backend has async initialization, in which case, you should call\n * 'await tf.ready()' before running other code.\n */\n/** @doc {heading: 'Backends'} */\nexport function backend(): KernelBackend {\n return ENGINE.backend;\n}\n\n/**\n * Sets the global platform.\n *\n * @param platformName The name of this platform.\n * @param platform A platform implementation.\n */\nexport function setPlatform(platformName: string, platform: Platform) {\n env().setPlatform(platformName, platform);\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from './environment';\n\nexport function warn(...msg: Array<{}>): void {\n if (!env().getBool('IS_TEST')) {\n console.warn(...msg);\n }\n}\n\nexport function log(...msg: Array<{}>): void {\n if (!env().getBool('IS_TEST')) {\n console.log(...msg);\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from './engine';\nimport {env} from './environment';\nimport {Tensor} from './tensor';\nimport {DataType, TensorLike} from './types';\nimport {assert, flatten, inferDtype, isTypedArray, toTypedArray} from './util';\n\nexport function inferShape(val: TensorLike, dtype?: DataType): number[] {\n let firstElem: typeof val = val;\n\n if (isTypedArray(val)) {\n return dtype === 'string' ? [] : [val.length];\n }\n if (!Array.isArray(val)) {\n return []; // Scalar.\n }\n const shape: number[] = [];\n\n while (Array.isArray(firstElem) ||\n isTypedArray(firstElem) && dtype !== 'string') {\n shape.push(firstElem.length);\n firstElem = firstElem[0];\n }\n if (Array.isArray(val) &&\n env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {\n deepAssertShapeConsistency(val, shape, []);\n }\n\n return shape;\n}\n\nfunction deepAssertShapeConsistency(\n val: TensorLike, shape: number[], indices: number[]) {\n indices = indices || [];\n if (!(Array.isArray(val)) && !isTypedArray(val)) {\n assert(\n shape.length === 0,\n () => `Element arr[${indices.join('][')}] is a primitive, ` +\n `but should be an array/TypedArray of ${shape[0]} elements`);\n return;\n }\n assert(\n shape.length > 0,\n () => `Element arr[${indices.join('][')}] should be a primitive, ` +\n `but is an array of ${val.length} elements`);\n assert(\n val.length === shape[0],\n () => `Element arr[${indices.join('][')}] should have ${shape[0]} ` +\n `elements, but has ${val.length} elements`);\n const subShape = shape.slice(1);\n for (let i = 0; i < val.length; ++i) {\n deepAssertShapeConsistency(val[i], subShape, indices.concat(i));\n }\n}\n\nfunction assertDtype(\n expectedDtype: DataType|'numeric', actualDType: DataType, argName: string,\n functionName: string) {\n if (expectedDtype == null) {\n return;\n }\n if (expectedDtype !== 'numeric' && expectedDtype !== actualDType ||\n expectedDtype === 'numeric' && actualDType === 'string') {\n throw new Error(\n `Argument '${argName}' passed to '${functionName}' must ` +\n `be ${expectedDtype} tensor, but got ${actualDType} tensor`);\n }\n}\n\nexport function convertToTensor(\n x: T|TensorLike, argName: string, functionName: string,\n parseAsDtype: DataType|'numeric' = 'numeric'): T {\n if (x instanceof Tensor) {\n assertDtype(parseAsDtype, x.dtype, argName, functionName);\n return x;\n }\n let inferredDtype = inferDtype(x);\n // If the user expects a bool/int/float, use that info to update the\n // inferredDtype when it is not a string.\n if (inferredDtype !== 'string' &&\n ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {\n inferredDtype = parseAsDtype as DataType;\n }\n assertDtype(parseAsDtype, inferredDtype, argName, functionName);\n\n if ((x == null) ||\n (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' &&\n typeof x !== 'boolean' && typeof x !== 'string')) {\n const type = x == null ? 'null' : (x as {}).constructor.name;\n throw new Error(\n `Argument '${argName}' passed to '${functionName}' must be a ` +\n `Tensor or TensorLike, but got '${type}'`);\n }\n const inferredShape = inferShape(x, inferredDtype);\n if (!isTypedArray(x) && !Array.isArray(x)) {\n x = [x] as number[];\n }\n const skipTypedArray = true;\n const values = inferredDtype !== 'string' ?\n toTypedArray(x, inferredDtype as DataType, env().getBool('DEBUG')) :\n flatten(x as string[], [], skipTypedArray) as string[];\n return ENGINE.makeTensor(values, inferredShape, inferredDtype) as T;\n}\n\nexport function convertToTensorArray(\n arg: Array, argName: string, functionName: string,\n parseAsDtype: DataType|'numeric' = 'numeric'): T[] {\n if (!Array.isArray(arg)) {\n throw new Error(\n `Argument ${argName} passed to ${functionName} must be a ` +\n '`Tensor[]` or `TensorLike[]`');\n }\n const tensors = arg as T[];\n return tensors.map(\n (t, i) => convertToTensor(t, `${argName}[${i}]`, functionName),\n parseAsDtype);\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as util from '../util';\n\n/**\n * Returns true if the axis specifies the inner most dimensions of the\n * array.\n */\nexport function axesAreInnerMostDims(axes: number[], rank: number): boolean {\n for (let i = 0; i < axes.length; ++i) {\n if (axes[axes.length - i - 1] !== rank - 1 - i) {\n return false;\n }\n }\n return true;\n}\n\nexport function combineLocations(\n outputLoc: number[], reduceLoc: number[], axes: number[]): number[] {\n const rank = outputLoc.length + reduceLoc.length;\n const loc = [];\n let outIdx = 0;\n let reduceIdx = 0;\n   for (let dim = 0; dim < rank; dim++) {\n if (axes.indexOf(dim) === -1) {\n loc.push(outputLoc[outIdx++]);\n } else {\n loc.push(reduceLoc[reduceIdx++]);\n }\n }\n return loc;\n}\n\nexport function computeOutAndReduceShapes(\n aShape: number[], axes: number[]): [number[], number[]] {\n const outShape = [];\n const rank = aShape.length;\n for (let dim = 0; dim < rank; dim++) {\n if (axes.indexOf(dim) === -1) {\n outShape.push(aShape[dim]);\n }\n }\n const reduceShape = axes.map(dim => aShape[dim]);\n return [outShape, reduceShape];\n}\n\nexport function expandShapeToKeepDim(\n shape: number[], axes: number[]): number[] {\n const reduceSubShape = axes.map(x => 1);\n return combineLocations(shape, reduceSubShape, axes);\n}\n\nexport function assertAxesAreInnerMostDims(\n msg: string, axes: number[], rank: number): void {\n util.assert(\n axesAreInnerMostDims(axes, rank),\n () => `${msg} supports only inner-most axes for now. ` +\n `Got axes ${axes} and rank-${rank} input.`);\n}\n\n/**\n * Returns the axes permutation to be used with `tf.transpose`, if such\n * permutation is necessary. Otherwise it returns null. This method is used by\n * operations that operate only on inner-most axes.\n */\nexport function getAxesPermutation(axes: number[], rank: number): number[]|\n null {\n if (axesAreInnerMostDims(axes, rank)) {\n return null;\n }\n const result: number[] = [];\n for (let i = 0; i < rank; ++i) {\n if (axes.indexOf(i) === -1) {\n result.push(i);\n }\n }\n axes.forEach(axis => result.push(axis));\n return result;\n}\n\n/** Returns the axes permutation that undoes the original permutation. */\nexport function getUndoAxesPermutation(axes: number[]): number[] {\n return axes.map((axis, i) => [i, axis])\n .sort((a, b) => a[1] - b[1])\n .map(x => x[0]);\n}\n\nexport function getInnerMostAxes(numAxes: number, rank: number): number[] {\n const res: number[] = [];\n for (let i = rank - numAxes; i < rank; ++i) {\n res.push(i);\n }\n return res;\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as util from '../util';\n\nexport function assertParamsConsistent(shapes: number[][], axis: number) {\n const rank = shapes[0].length;\n shapes.forEach((shape, i) => {\n util.assert(\n shape.length === rank,\n () =>\n `Error in concat${rank}D: rank of tensors[${i}] must be the same ` +\n `as the rank of the rest (${rank})`);\n });\n\n util.assert(\n axis >= 0 && axis < rank,\n () => `Error in concat${rank}D: axis must be between 0 and ${rank - 1}.`);\n\n const firstShape = shapes[0];\n shapes.forEach((shape, i) => {\n for (let r = 0; r < rank; r++) {\n util.assert(\n (r === axis) || (shape[r] === firstShape[r]),\n () => `Error in concat${rank}D: Shape of tensors[${i}] (${shape}) ` +\n `does not match the shape of the rest (${firstShape}) ` +\n `along the non-concatenated axis ${i}.`);\n }\n });\n}\n\nexport function computeOutShape(shapes: number[][], axis: number): number[] {\n const outputShape = shapes[0].slice();\n for (let i = 1; i < shapes.length; i++) {\n outputShape[axis] += shapes[i][axis];\n }\n return outputShape;\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {ENGINE} from '../engine';\n\n/**\n * Used for wrapping functions that perform math operations on\n * Tensors. The function will be wrapped in a named scope that cleans all\n * memory usage after the function is done.\n */\nexport function op(f: {[name: string]: T}): T {\n const keys = Object.keys(f);\n if (keys.length !== 1) {\n throw new Error(\n `Please provide an object with a single key ` +\n `(operation name) mapping to a function. Got an object with ` +\n `${keys.length} keys.`);\n }\n\n let opName = keys[0];\n const fn = f[opName];\n\n // Strip the underscore from the end of the function name.\n if (opName.endsWith('_')) {\n opName = opName.substring(0, opName.length - 1);\n }\n\n // tslint:disable-next-line:no-any\n const f2 = (...args: any[]) => {\n ENGINE.startScope(opName);\n try {\n const result = fn(...args);\n if (result instanceof Promise) {\n console.error('Cannot return a Promise inside of tidy.');\n }\n ENGINE.endScope(result);\n return result;\n } catch (ex) {\n ENGINE.endScope(null);\n throw ex;\n }\n };\n Object.defineProperty(f2, 'name', {value: opName, configurable: true});\n\n // tslint:disable-next-line:no-any\n return f2 as any as T;\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {ENGINE} from '../engine';\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\nimport {op} from './operation';\n\n/**\n * Converts two real numbers to a complex number.\n *\n * Given a tensor `real` representing the real part of a complex number, and a\n * tensor `imag` representing the imaginary part of a complex number, this\n * operation returns complex numbers elementwise of the form [r0, i0, r1, i1],\n * where r represents the real part and i represents the imag part.\n *\n * The input tensors real and imag must have the same shape.\n *\n * ```js\n * const real = tf.tensor1d([2.25, 3.25]);\n * const imag = tf.tensor1d([4.75, 5.75]);\n * const complex = tf.complex(real, imag);\n *\n * complex.print();\n * ```\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction complex_(real: T|TensorLike, imag: T|TensorLike): T {\n const $real = convertToTensor(real, 'real', 'complex');\n const $imag = convertToTensor(imag, 'imag', 'complex');\n util.assertShapesMatch(\n $real.shape, $imag.shape,\n `real and imag shapes, ${$real.shape} and ${$imag.shape}, ` +\n `must match in call to tf.complex().`);\n\n return ENGINE.runKernelFunc(\n backend => backend.complex($real, $imag), {$real, $imag});\n}\n\n/**\n * Returns the real part of a complex (or real) tensor.\n *\n * Given a tensor input, this operation returns a tensor of type float that is\n * the real part of each element in input considered as a complex number.\n *\n * If the input is real, it simply makes a clone.\n *\n * ```js\n * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);\n * tf.real(x).print();\n * ```\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction real_(input: T|TensorLike): T {\n const $input = convertToTensor(input, 'input', 'real');\n\n return ENGINE.runKernelFunc(backend => backend.real($input), {$input});\n}\n\n/**\n * Returns the imaginary part of a complex (or real) tensor.\n *\n * Given a tensor input, this operation returns a tensor of type float that is\n * the imaginary part of each element in input considered as a complex number.\n * If input is real, a tensor of all zeros is returned.\n *\n * ```js\n * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);\n * tf.imag(x).print();\n * ```\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction imag_(input: T|TensorLike): T {\n const $input = convertToTensor(input, 'input', 'imag');\n\n return ENGINE.runKernelFunc(backend => backend.imag($input), {$input});\n}\n\nexport const complex = op({complex_});\nexport const real = op({real_});\nexport const imag = op({imag_});\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {env} from '../environment';\n\nimport {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor6D, Variable} from '../tensor';\nimport {convertToTensor, inferShape} from '../tensor_util_env';\nimport {TensorLike, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, TensorLike5D, TensorLike6D, TypedArray} from '../types';\nimport {DataType, Rank, ShapeMap} from '../types';\nimport {assert, assertNonNegativeIntegerDimensions, assertNonNull, flatten, inferDtype, isTypedArray, makeOnesTypedArray, makeZerosTypedArray, sizeFromShape, toTypedArray} from '../util';\nimport {complex, imag, real} from './complex_ops';\nimport {op} from './operation';\n\n/**\n * Creates a `tf.Tensor` with the provided values, shape and dtype.\n *\n * ```js\n * // Pass an array of values to create a vector.\n * tf.tensor([1, 2, 3, 4]).print();\n * ```\n *\n * ```js\n * // Pass a nested array of values to make a matrix or a higher\n * // dimensional tensor.\n * tf.tensor([[1, 2], [3, 4]]).print();\n * ```\n *\n * ```js\n * // Pass a flat array and specify a shape yourself.\n * tf.tensor([1, 2, 3, 4], [2, 2]).print();\n * ```\n *\n * @param values The values of the tensor. Can be nested array of numbers,\n * or a flat array, or a `TypedArray`. If the values are strings,\n * they will be encoded as utf-8 and kept as `Uint8Array[]`.\n * @param shape The shape of the tensor. Optional. If not provided,\n * it is inferred from `values`.\n * @param dtype The data type.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction tensor(\n values: TensorLike, shape?: ShapeMap[R], dtype?: DataType): Tensor {\n const inferredShape = inferShape(values, dtype);\n return makeTensor(values, shape, inferredShape, dtype) as Tensor;\n}\n\n/** This is shared code across all tensor creation methods. */\nfunction makeTensor(\n values: TensorLike, shape: number[], inferredShape: number[],\n dtype?: DataType): Tensor {\n if (dtype == null) {\n dtype = inferDtype(values);\n }\n if (dtype === 'complex64') {\n throw new Error(\n `Cannot construct a complex64 tensor directly. ` +\n `Please use tf.complex(real, imag).`);\n }\n if (!isTypedArray(values) && !Array.isArray(values) &&\n typeof values !== 'number' && typeof values !== 'boolean' &&\n typeof values !== 'string') {\n throw new Error(\n 'values passed to tensor(values) must be a number/boolean/string or ' +\n 'an array of numbers/booleans/strings, or a TypedArray');\n }\n if (shape != null) {\n assertNonNegativeIntegerDimensions(shape);\n\n const providedSize = sizeFromShape(shape);\n const inferredSize = sizeFromShape(inferredShape);\n assert(\n providedSize === inferredSize,\n () =>\n `Based on the provided shape, [${shape}], the tensor should have ` +\n `${providedSize} values but has ${inferredSize}`);\n\n for (let i = 0; i < inferredShape.length; ++i) {\n const inferred = inferredShape[i];\n const flatDimsDontMatch = i === inferredShape.length - 1 ?\n inferred !== sizeFromShape(shape.slice(i)) :\n true;\n assert(\n inferredShape[i] === shape[i] || !flatDimsDontMatch,\n () => `Error creating a new Tensor. Inferred shape ` +\n `(${inferredShape}) does not match the provided ` +\n `shape (${shape}). `);\n }\n }\n\n if (!isTypedArray(values) && !Array.isArray(values)) {\n values = [values] as number[];\n }\n\n shape = shape || inferredShape;\n values = dtype !== 'string' ?\n toTypedArray(values, dtype, env().getBool('DEBUG')) :\n flatten(values as string[], [], true) as string[];\n return ENGINE.makeTensor(values as TypedArray, shape, dtype);\n}\n\n/**\n * Creates rank-0 `tf.Tensor` (scalar) with the provided value and dtype.\n *\n * The same functionality can be achieved with `tf.tensor`, but in general\n * we recommend using `tf.scalar` as it makes the code more readable.\n *\n * ```js\n * tf.scalar(3.14).print();\n * ```\n *\n * @param value The value of the scalar.\n * @param dtype The data type.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction scalar(\n value: number|boolean|string|Uint8Array, dtype?: DataType): Scalar {\n if (((isTypedArray(value) && dtype !== 'string') || Array.isArray(value)) &&\n dtype !== 'complex64') {\n throw new Error(\n 'Error creating a new Scalar: value must be a primitive ' +\n '(number|boolean|string)');\n }\n if (dtype === 'string' && isTypedArray(value) &&\n !(value instanceof Uint8Array)) {\n throw new Error(\n 'When making a scalar from encoded string, ' +\n 'the value must be `Uint8Array`.');\n }\n const shape: number[] = [];\n const inferredShape: number[] = [];\n return makeTensor(value, shape, inferredShape, dtype) as Scalar;\n}\n\n/**\n * Creates rank-1 `tf.Tensor` with the provided values, shape and dtype.\n *\n * The same functionality can be achieved with `tf.tensor`, but in general\n * we recommend using `tf.tensor1d` as it makes the code more readable.\n *\n * ```js\n * tf.tensor1d([1, 2, 3]).print();\n * ```\n *\n * @param values The values of the tensor. Can be array of numbers,\n * or a `TypedArray`.\n * @param dtype The data type.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction tensor1d(values: TensorLike1D, dtype?: DataType): Tensor1D {\n assertNonNull(values);\n const inferredShape = inferShape(values, dtype);\n if (inferredShape.length !== 1) {\n throw new Error('tensor1d() requires values to be a flat/TypedArray');\n }\n const shape: number[] = null;\n return makeTensor(values, shape, inferredShape, dtype) as Tensor1D;\n}\n\n/**\n * Creates rank-2 `tf.Tensor` with the provided values, shape and dtype.\n *\n * The same functionality can be achieved with `tf.tensor`, but in general\n * we recommend using `tf.tensor2d` as it makes the code more readable.\n *\n * ```js\n * // Pass a nested array.\n * tf.tensor2d([[1, 2], [3, 4]]).print();\n * ```\n * ```js\n * // Pass a flat array and specify a shape.\n * tf.tensor2d([1, 2, 3, 4], [2, 2]).print();\n * ```\n *\n * @param values The values of the tensor. Can be nested array of numbers,\n * or a flat array, or a `TypedArray`.\n * @param shape The shape of the tensor. If not provided, it is inferred from\n * `values`.\n * @param dtype The data type.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction tensor2d(\n values: TensorLike2D, shape?: [number, number],\n dtype?: DataType): Tensor2D {\n assertNonNull(values);\n if (shape != null && shape.length !== 2) {\n throw new Error('tensor2d() requires shape to have two numbers');\n }\n const inferredShape = inferShape(values, dtype);\n if (inferredShape.length !== 2 && inferredShape.length !== 1) {\n throw new Error(\n 'tensor2d() requires values to be number[][] or flat/TypedArray');\n }\n if (inferredShape.length === 1 && shape == null) {\n throw new Error(\n 'tensor2d() requires shape to be provided when `values` ' +\n 'are a flat/TypedArray');\n }\n return makeTensor(values, shape, inferredShape, dtype) as Tensor2D;\n}\n\n/**\n * Creates rank-3 `tf.Tensor` with the provided values, shape and dtype.\n *\n * The same functionality can be achieved with `tf.tensor`, but in general\n * we recommend using `tf.tensor3d` as it makes the code more readable.\n *\n * ```js\n * // Pass a nested array.\n * tf.tensor3d([[[1], [2]], [[3], [4]]]).print();\n * ```\n * ```js\n * // Pass a flat array and specify a shape.\n * tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print();\n * ```\n *\n * @param values The values of the tensor. Can be nested array of numbers,\n * or a flat array, or a `TypedArray`.\n * @param shape The shape of the tensor. If not provided, it is inferred from\n * `values`.\n * @param dtype The data type.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction tensor3d(\n values: TensorLike3D, shape?: [number, number, number],\n dtype?: DataType): Tensor3D {\n assertNonNull(values);\n if (shape != null && shape.length !== 3) {\n throw new Error('tensor3d() requires shape to have three numbers');\n }\n const inferredShape = inferShape(values, dtype);\n if (inferredShape.length !== 3 && inferredShape.length !== 1) {\n throw new Error(\n 'tensor3d() requires values to be number[][][] or flat/TypedArray');\n }\n if (inferredShape.length === 1 && shape == null) {\n throw new Error(\n 'tensor3d() requires shape to be provided when `values` ' +\n 'are a flat array');\n }\n return makeTensor(values, shape, inferredShape, dtype) as Tensor3D;\n}\n\n/**\n * Creates rank-4 `tf.Tensor` with the provided values, shape and dtype.\n *\n * The same functionality can be achieved with `tf.tensor`, but in general\n * we recommend using `tf.tensor4d` as it makes the code more readable.\n *\n * ```js\n * // Pass a nested array.\n * tf.tensor4d([[[[1], [2]], [[3], [4]]]]).print();\n * ```\n * ```js\n * // Pass a flat array and specify a shape.\n * tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]).print();\n * ```\n *\n * @param values The values of the tensor. Can be nested array of numbers,\n * or a flat array, or a `TypedArray`.\n * @param shape The shape of the tensor. Optional. If not provided,\n * it is inferred from `values`.\n * @param dtype The data type.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction tensor4d(\n values: TensorLike4D, shape?: [number, number, number, number],\n dtype?: DataType): Tensor4D {\n assertNonNull(values);\n if (shape != null && shape.length !== 4) {\n throw new Error('tensor4d() requires shape to have four numbers');\n }\n const inferredShape = inferShape(values, dtype);\n if (inferredShape.length !== 4 && inferredShape.length !== 1) {\n throw new Error(\n 'tensor4d() requires values to be number[][][][] or flat/TypedArray');\n }\n if (inferredShape.length === 1 && shape == null) {\n throw new Error(\n 'tensor4d() requires shape to be provided when `values` ' +\n 'are a flat array');\n }\n return makeTensor(values, shape, inferredShape, dtype) as Tensor4D;\n}\n\n/**\n * Creates rank-5 `tf.Tensor` with the provided values, shape and dtype.\n *\n * The same functionality can be achieved with `tf.tensor`, but in general\n * we recommend using `tf.tensor5d` as it makes the code more readable.\n *\n * ```js\n * // Pass a nested array.\n * tf.tensor5d([[[[[1], [2]], [[3], [4]]]]]).print();\n * ```\n * ```js\n * // Pass a flat array and specify a shape.\n * tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]).print();\n * ```\n *\n * @param values The values of the tensor. Can be nested array of numbers,\n * or a flat array, or a `TypedArray`.\n * @param shape The shape of the tensor. Optional. If not provided,\n * it is inferred from `values`.\n * @param dtype The data type.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction tensor5d(\n values: TensorLike5D, shape?: [number, number, number, number, number],\n dtype?: DataType): Tensor5D {\n assertNonNull(values);\n if (shape != null && shape.length !== 5) {\n throw new Error('tensor5d() requires shape to have five numbers');\n }\n const inferredShape = inferShape(values, dtype);\n if (inferredShape.length !== 5 && inferredShape.length !== 1) {\n throw new Error(\n 'tensor5d() requires values to be ' +\n 'number[][][][][] or flat/TypedArray');\n }\n if (inferredShape.length === 1 && shape == null) {\n throw new Error(\n 'tensor5d() requires shape to be provided when `values` ' +\n 'are a flat array');\n }\n return makeTensor(values, shape, inferredShape, dtype) as Tensor5D;\n}\n\n/**\n * Creates rank-6 `tf.Tensor` with the provided values, shape and dtype.\n *\n * The same functionality can be achieved with `tf.tensor`, but in general\n * we recommend using `tf.tensor6d` as it makes the code more readable.\n *\n * ```js\n * // Pass a nested array.\n * tf.tensor6d([[[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]]).print();\n * ```\n * ```js\n * // Pass a flat array and specify a shape.\n * tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 2, 2, 2, 1]).print();\n * ```\n *\n * @param values The values of the tensor. Can be nested array of numbers,\n * or a flat array, or a `TypedArray`.\n * @param shape The shape of the tensor. Optional. If not provided,\n * it is inferred from `values`.\n * @param dtype The data type.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction tensor6d(\n values: TensorLike6D,\n shape?: [number, number, number, number, number, number],\n dtype?: DataType): Tensor6D {\n assertNonNull(values);\n if (shape != null && shape.length !== 6) {\n throw new Error('tensor6d() requires shape to have six numbers');\n }\n const inferredShape = inferShape(values, dtype);\n if (inferredShape.length !== 6 && inferredShape.length !== 1) {\n throw new Error(\n 'tensor6d() requires values to be number[][][][][][] or ' +\n 'flat/TypedArray');\n }\n if (inferredShape.length === 1 && shape == null) {\n throw new Error(\n 'tensor6d() requires shape to be provided when `values` ' +\n 'are a flat array');\n }\n shape = shape ||\n inferredShape as [number, number, number, number, number, number];\n return makeTensor(values, shape, inferredShape, dtype) as Tensor6D;\n}\n\n/**\n * Creates a new variable with the provided initial value.\n * ```js\n * const x = tf.variable(tf.tensor([1, 2, 3]));\n * x.assign(tf.tensor([4, 5, 6]));\n *\n * x.print();\n * ```\n *\n * @param initialValue Initial value for the tensor.\n * @param trainable If true, optimizers are allowed to update it.\n * @param name Name of the variable. Defaults to a unique id.\n * @param dtype If set, initialValue will be converted to the given type.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction variable(\n initialValue: Tensor, trainable = true, name?: string,\n dtype?: DataType): Variable {\n return ENGINE.makeVariable(initialValue, trainable, name, dtype) as\n Variable;\n}\n\n/**\n * Creates a `tf.Tensor` with all elements set to 1.\n *\n * ```js\n * tf.ones([2, 2]).print();\n * ```\n *\n * @param shape An array of integers defining the output tensor shape.\n * @param dtype The type of an element in the resulting tensor. Defaults to\n * 'float'.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction ones(\n shape: ShapeMap[R], dtype: DataType = 'float32'): Tensor {\n if (dtype === 'complex64') {\n const real = ones(shape, 'float32');\n const imag = zeros(shape, 'float32');\n return complex(real, imag);\n }\n const values = makeOnesTypedArray(sizeFromShape(shape), dtype);\n return ENGINE.makeTensor(values, shape, dtype) as Tensor;\n}\n\n/**\n * Creates a `tf.Tensor` with all elements set to 0.\n *\n * ```js\n * tf.zeros([2, 2]).print();\n * ```\n *\n * @param shape An array of integers defining the output tensor shape.\n * @param dtype The type of an element in the resulting tensor. Can\n * be 'float32', 'int32' or 'bool'. Defaults to 'float'.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction zeros(\n shape: ShapeMap[R], dtype: DataType = 'float32'): Tensor {\n if (dtype === 'complex64') {\n const real = zeros(shape, 'float32');\n const imag = zeros(shape, 'float32');\n return complex(real, imag);\n }\n const values = makeZerosTypedArray(sizeFromShape(shape), dtype);\n return ENGINE.makeTensor(values, shape, dtype) as Tensor;\n}\n\n/**\n * Creates a `tf.Tensor` filled with a scalar value.\n *\n * ```js\n * tf.fill([2, 2], 4).print();\n * ```\n *\n * @param shape An array of integers defining the output tensor shape.\n * @param value The scalar value to fill the tensor with.\n * @param dtype The type of an element in the resulting tensor. Defaults to\n * 'float'.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction fill(\n shape: ShapeMap[R], value: number|string, dtype?: DataType): Tensor {\n return ENGINE.runKernelFunc(backend => backend.fill(shape, value, dtype), {});\n}\n\n/**\n * Creates a `tf.Tensor` with all elements set to 1 with the same shape as the\n * given tensor.\n *\n * ```js\n * const x = tf.tensor([1, 2]);\n * tf.onesLike(x).print();\n * ```\n * @param x A tensor.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction onesLike_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'onesLike');\n if ($x.dtype === 'complex64') {\n const r = onesLike(real($x));\n const i = zerosLike(imag($x));\n return complex(r, i);\n }\n const der = (dy: T, saved: Tensor[]) => ({x: () => zerosLike(dy)});\n return ENGINE.runKernelFunc(\n backend => backend.onesLike($x), {x: $x}, der, 'OnesLike') as T;\n}\n\n/**\n * Creates a `tf.Tensor` with all elements set to 0 with the same shape as the\n * given tensor.\n *\n * ```js\n * const x = tf.tensor([1, 2]);\n * tf.zerosLike(x).print();\n * ```\n *\n * @param x The tensor of required shape.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction zerosLike_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'zerosLike');\n const der = (dy: T, saved: Tensor[]) => ({x: () => zerosLike(dy)});\n return ENGINE.runKernelFunc(\n backend => backend.zerosLike($x), {x: $x}, der, 'ZerosLike') as T;\n}\n\n/**\n * Return an evenly spaced sequence of numbers over the given interval.\n *\n * ```js\n * tf.linspace(0, 9, 10).print();\n * ```\n * @param start The start value of the sequence.\n * @param stop The end value of the sequence.\n * @param num The number of values to generate.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction linspace(start: number, stop: number, num: number): Tensor1D {\n if (num <= 0) {\n throw new Error('The number of values should be positive.');\n }\n return ENGINE.runKernelFunc(\n backend => backend.linspace(start, stop, num), {});\n}\n\n/**\n * Creates a new `tf.Tensor1D` filled with the numbers in the range provided.\n *\n * The tensor is a is half-open interval meaning it includes start, but\n * excludes stop. Decrementing ranges and negative step values are also\n * supported.\n *\n * ```js\n * tf.range(0, 9, 2).print();\n * ```\n *\n * @param start An integer start value\n * @param stop An integer stop value\n * @param step An integer increment (will default to 1 or -1)\n * @param dtype The data type of the output tensor. Defaults to 'float32'.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction range(\n start: number, stop: number, step = 1,\n dtype: 'float32'|'int32' = 'float32'): Tensor1D {\n if (step === 0) {\n throw new Error('Cannot have a step of zero');\n }\n\n const sameStartStop = start === stop;\n const increasingRangeNegativeStep = start < stop && step < 0;\n const decreasingRangePositiveStep = stop < start && step > 1;\n\n if (sameStartStop || increasingRangeNegativeStep ||\n decreasingRangePositiveStep) {\n return zeros([0], dtype);\n }\n\n const numElements = Math.abs(Math.ceil((stop - start) / step));\n const values = makeZerosTypedArray(numElements, dtype);\n\n if (stop < start && step === 1) {\n // Auto adjust the step's sign if it hasn't been set\n // (or was set to 1)\n step = -1;\n }\n\n values[0] = start;\n for (let i = 1; i < values.length; i++) {\n values[i] = values[i - 1] + step;\n }\n\n return tensor1d(values, dtype);\n}\n\nexport {\n fill,\n linspace,\n ones,\n range,\n scalar,\n tensor,\n tensor1d,\n tensor2d,\n tensor3d,\n tensor4d,\n tensor5d,\n tensor6d,\n variable,\n zeros\n};\n\nexport const onesLike = op({onesLike_});\nexport const zerosLike = op({zerosLike_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';\nimport {convertToTensor, convertToTensorArray} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {assert, sizeFromShape} from '../util';\nimport {parseAxisParam} from '../util';\nimport {assertParamsConsistent, computeOutShape} from './concat_util';\nimport {op} from './operation';\nimport {tensor} from './tensor_ops';\n\n/**\n * Concatenates a list of`tf.Tensor1D`s along an axis. See `concat` for details.\n *\n * For example, if:\n * A: shape(3) = |r1, g1, b1|\n * B: shape(2) = |r2, g2|\n * C = tf.concat1d([A, B]) == |r1, g1, b1, r2, g2|\n *\n * @param tensors A list of`tf.Tensor`s to concatenate.\n * @return The concatenated array.\n */\nfunction concat1d_(tensors: Array): Tensor1D {\n return concat(tensors, 0 /* axis */);\n}\n\n/**\n * Concatenates a list of`tf.Tensor2D`s along an axis. See `concat` for details.\n *\n * For example, if:\n * A: shape(2, 3) = | r1, g1, b1 |\n * | r2, g2, b2 |\n *\n * B: shape(2, 3) = | r3, g3, b3 |\n * | r4, g4, b4 |\n *\n * C = tf.concat2d([A, B], axis)\n *\n * if axis = 0:\n * C: shape(4, 3) = | r1, g1, b1 |\n * | r2, g2, b2 |\n * | r3, g3, b3 |\n * | r4, g4, b4 |\n *\n * if axis = 1:\n * C = shape(2, 6) = | r1, g1, b1, r3, g3, b3 |\n * | r2, g2, b2, r4, g4, b4 |\n *\n *\n * @param tensors A list of `tf.Tensor`s to concatenate.\n * @param axis The axis to concatenate along.\n * @return The concatenated array.\n */\nfunction concat2d_(\n tensors: Array, axis: number): Tensor2D {\n return concat(tensors, axis);\n}\n\n/**\n * Concatenates a list of `tf.Tensor3D`s along an axis.\n * See `concat` for details.\n *\n * For example, if:\n * A: shape(2, 1, 3) = | r1, g1, b1 |\n * | r2, g2, b2 |\n *\n * B: shape(2, 1, 3) = | r3, g3, b3 |\n * | r4, g4, b4 |\n *\n * C = tf.concat3d([A, B], axis)\n *\n * if axis = 0:\n * C: shape(4, 1, 3) = | r1, g1, b1 |\n * | r2, g2, b2 |\n * | r3, g3, b3 |\n * | r4, g4, b4 |\n *\n * if axis = 1:\n * C: shape(2, 2, 3) = | r1, g1, b1, r3, g3, b3 |\n * | r2, g2, b2, r4, g4, b4 |\n *\n * if axis = 2:\n * C = shape(2, 1, 6) = | r1, g1, b1, r3, g3, b3 |\n * | r2, g2, b2, r4, g4, b4 |\n *\n * @param tensors A list of`tf.Tensor`s to concatenate.\n * @param axis The axis to concate along.\n * @return The concatenated array.\n */\nfunction concat3d_(\n tensors: Array, axis: number): Tensor3D {\n return concat(tensors, axis);\n}\n\n/**\n * Concatenates a list of `tf.Tensor4D`s along an axis.\n * See `concat` for details.\n *\n * @param tensors A list of `tf.Tensor`s to concatenate.\n * @param axis The axis to concate along.\n * @return The concatenated array.\n */\nfunction concat4d_(\n tensors: Array, axis: number): Tensor4D {\n return concat(tensors, axis);\n}\n\n/**\n * Concatenates a list of `tf.Tensor`s along a given axis.\n *\n * The tensors ranks and types must match, and their sizes must match in all\n * dimensions except `axis`.\n *\n * Also available are stricter rank-specific methods that assert that\n * `tensors` are of the given rank:\n * - `tf.concat1d`\n * - `tf.concat2d`\n * - `tf.concat3d`\n * - `tf.concat4d`\n *\n * Except `tf.concat1d` (which does not have axis param), all methods have\n * same signature as this method.\n *\n * ```js\n * const a = tf.tensor1d([1, 2]);\n * const b = tf.tensor1d([3, 4]);\n * a.concat(b).print(); // or a.concat(b)\n * ```\n *\n * ```js\n * const a = tf.tensor1d([1, 2]);\n * const b = tf.tensor1d([3, 4]);\n * const c = tf.tensor1d([5, 6]);\n * tf.concat([a, b, c]).print();\n * ```\n *\n * ```js\n * const a = tf.tensor2d([[1, 2], [10, 20]]);\n * const b = tf.tensor2d([[3, 4], [30, 40]]);\n * const axis = 1;\n * tf.concat([a, b], axis).print();\n * ```\n * @param tensors A list of tensors to concatenate.\n * @param axis The axis to concate along. Defaults to 0 (the first dim).\n */\n/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */\nfunction concat_(tensors: Array, axis = 0): T {\n assert(tensors.length >= 1, () => 'Pass at least one tensor to concat');\n let $tensors = convertToTensorArray(tensors, 'tensors', 'concat');\n if ($tensors[0].dtype === 'complex64') {\n $tensors.forEach(tensor => {\n if (tensor.dtype !== 'complex64') {\n throw new Error(`Cannot concatenate complex64 tensors with a tensor\n with dtype ${tensor.dtype}. `);\n }\n });\n }\n\n axis = parseAxisParam(axis, $tensors[0].shape)[0];\n const outShape = computeOutShape($tensors.map(t => t.shape), axis);\n if (sizeFromShape(outShape) === 0) {\n return tensor([], outShape) as T;\n }\n // Keep only non-empty tensors (ignore tensors with 0 in their shape).\n $tensors = $tensors.filter(t => t.size > 0);\n if ($tensors.length === 1) {\n return $tensors[0];\n }\n\n const shapes = $tensors.map(t => t.shape);\n assertParamsConsistent(shapes, axis);\n const der = (dy: T) => {\n const sizeSplits = shapes.map(s => s[axis]);\n const derTensors = split(dy, sizeSplits, axis);\n return derTensors.map(t => () => t) as {};\n };\n const inputs = $tensors as {};\n const attr = {axis};\n return ENGINE.runKernelFunc(\n backend => backend.concat($tensors, axis) as T, inputs, der, 'Concat',\n attr);\n}\n\n/**\n * Splits a `tf.Tensor` into sub tensors.\n *\n * If `numOrSizeSplits` is a number, splits `x` along dimension `axis`\n * into `numOrSizeSplits` smaller tensors.\n * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`.\n *\n * If `numOrSizeSplits` is a number array, splits `x` into\n * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the\n * same size as `x` except along dimension `axis` where the size is\n * `numOrSizeSplits[i]`.\n *\n * ```js\n * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);\n * const [a, b] = tf.split(x, 2, 1);\n * a.print();\n * b.print();\n *\n * const [c, d, e] = tf.split(x, [1, 2, 1], 1);\n * c.print();\n * d.print();\n * e.print();\n * ```\n *\n * @param x The input tensor to split.\n * @param numOrSizeSplits Either an integer indicating the number of\n * splits along the axis or an array of integers containing the sizes of\n * each output tensor along the axis. If a number then it must evenly divide\n * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`.\n * @param axis The dimension along which to split. Defaults to 0 (the first\n * dim).\n */\n/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */\nfunction split_(\n x: T|TensorLike, numOrSizeSplits: number[]|number, axis = 0): T[] {\n const $x = convertToTensor(x, 'x', 'split');\n\n axis = parseAxisParam(axis, $x.shape)[0];\n let splitSizes: number[];\n if (typeof (numOrSizeSplits) === 'number') {\n assert(\n $x.shape[axis] % numOrSizeSplits === 0,\n () => 'Number of splits must evenly divide the axis.');\n splitSizes =\n new Array(numOrSizeSplits).fill($x.shape[axis] / numOrSizeSplits);\n } else {\n assert(\n $x.shape[axis] === numOrSizeSplits.reduce((a, b) => a + b),\n () => 'The sum of sizes must match the size of the axis dimension.');\n splitSizes = numOrSizeSplits;\n }\n const der = (dy: T[]) => ({$x: () => concat(dy, axis)});\n return ENGINE.runKernelFunc(\n backend => backend.split($x, splitSizes, axis), {$x}, der);\n}\n\nexport const concat = op({concat_});\nexport const concat1d = op({concat1d_});\nexport const concat2d = op({concat2d_});\nexport const concat3d = op({concat3d_});\nexport const concat4d = op({concat4d_});\nexport const split = op({split_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor, Tensor4D, TensorBuffer} from '../tensor';\nimport {convertToTensor, convertToTensorArray} from '../tensor_util_env';\nimport {DataType, DataTypeMap, Rank, ShapeMap, TensorLike, TensorLike4D} from '../types';\nimport * as util from '../util';\nimport {getAxesPermutation, getInnerMostAxes} from './axis_util';\nimport {concat} from './concat_split';\nimport {op} from './operation';\n\n/**\n * Reshapes a `tf.Tensor` to a given shape.\n *\n * Given an input tensor, returns a new tensor with the same values as the\n * input tensor with shape `shape`.\n *\n * If one component of shape is the special value -1, the size of that\n * dimension is computed so that the total size remains constant. In\n * particular, a shape of [-1] flattens into 1-D. At most one component of\n * shape can be -1.\n *\n * If shape is 1-D or higher, then the operation returns a tensor with shape\n * shape filled with the values of tensor. In this case, the number of\n * elements implied by shape must be the same as the number of elements in\n * tensor.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3, 4]);\n * x.reshape([2, 2]).print();\n * ```\n *\n * @param x The input tensor to be reshaped.\n * @param shape An array of integers defining the output tensor shape.\n */\n/** @doc {heading: 'Tensors', subheading: 'Transformations'} */\nfunction reshape_(\n x: Tensor|TensorLike, shape: ShapeMap[R2]): Tensor {\n const $x = convertToTensor(x, 'x', 'reshape', null);\n shape = util.inferFromImplicitShape(shape, $x.size) as ShapeMap[R2];\n util.assert(\n $x.size === util.sizeFromShape(shape),\n () => 'new shape and old shape must have the same number of elements.');\n\n const grad = (dy: Tensor) => {\n return {x: () => dy.reshape($x.shape)};\n };\n const attrs = {shape};\n return ENGINE.runKernelFunc(\n backend => backend.reshape($x, shape), {x: $x}, grad, 'Reshape', attrs);\n}\n\n/**\n * Removes dimensions of size 1 from the shape of a `tf.Tensor`.\n *\n * ```js\n * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]);\n * x.squeeze().print();\n * ```\n *\n * @param x The input tensor to be squeezed.\n * @param axis An optional list of numbers. If specified, only\n * squeezes the dimensions listed. The dimension index starts at 0. It\n * is an error to squeeze a dimension that is not 1.\n */\n/** @doc {heading: 'Tensors', subheading: 'Transformations'} */\nfunction squeeze_(x: Tensor|TensorLike, axis?: number[]): T {\n const $x = convertToTensor(x, 'x', 'squeeze');\n return reshape($x, util.squeezeShape($x.shape, axis).newShape) as T;\n}\n\n/**\n * Casts a `tf.Tensor` to a new dtype.\n *\n * ```js\n * const x = tf.tensor1d([1.5, 2.5, 3]);\n * tf.cast(x, 'int32').print();\n * ```\n * @param x The input tensor to be casted.\n * @param dtype The dtype to cast the input tensor to.\n */\n/** @doc {heading: 'Tensors', subheading: 'Transformations'} */\nfunction cast_(x: T|TensorLike, dtype: DataType): T {\n const $x = convertToTensor(x, 'x', 'cast');\n\n // Sanity checks.\n if (!util.isValidDtype(dtype)) {\n throw new Error(`Failed to cast to unknown dtype ${dtype}`);\n }\n if (dtype === 'string' && $x.dtype !== 'string' ||\n dtype !== 'string' && $x.dtype === 'string') {\n throw new Error('Only strings can be casted to strings');\n }\n\n const grad = (dy: T) => {\n return {x: () => dy.clone()};\n };\n const attrs = {dtype};\n return ENGINE.runKernelFunc(\n backend => backend.cast($x, dtype), {x: $x}, grad, 'Cast', attrs);\n}\n\n/**\n * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`.\n *\n * ```js\n * const a = tf.tensor1d([1, 2]);\n * const b = tf.tensor1d([3, 4]);\n * const c = tf.tensor1d([5, 6]);\n * tf.stack([a, b, c]).print();\n * ```\n *\n * @param tensors A list of tensor objects with the same shape and dtype.\n * @param axis The axis to stack along. Defaults to 0 (the first dim).\n */\n/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */\nfunction stack_(\n tensors: Array, axis = 0): Tensor {\n const $tensors = convertToTensorArray(tensors, 'tensors', 'stack');\n\n util.assert(\n $tensors.length >= 1, () => 'Pass at least one tensor to tf.stack');\n if ($tensors.length === 1) {\n return $tensors[0].expandDims(axis);\n }\n const rank = $tensors[0].rank;\n const shape = $tensors[0].shape;\n const dtype = $tensors[0].dtype;\n\n util.assert(axis <= rank, () => 'Axis must be <= rank of the tensor');\n\n $tensors.forEach(t => {\n util.assertShapesMatch(\n shape, t.shape,\n 'All tensors passed to stack must have matching shapes');\n });\n\n $tensors.forEach(t => {\n util.assert(\n dtype === t.dtype,\n () => 'All tensors passed to stack must have matching dtypes');\n });\n const expandedTensors = $tensors.map(t => t.expandDims(axis));\n return concat(expandedTensors, axis);\n}\n\n/**\n * This operation reshapes the \"batch\" dimension 0 into `M + 1` dimensions of\n * shape `blockShape + [batch]`, interleaves these blocks back into the grid\n * defined by the spatial dimensions `[1, ..., M]`, to obtain a result with\n * the same rank as the input. The spatial dimensions of this intermediate\n * result are then optionally cropped according to `crops` to produce the\n * output. This is the reverse of `tf.spaceToBatchND`. See below for a precise\n * description.\n *\n * ```js\n * const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);\n * const blockShape = [2, 2];\n * const crops = [[0, 0], [0, 0]];\n *\n * x.batchToSpaceND(blockShape, crops).print();\n * ```\n *\n * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +\n * remainingShape`, where spatialShape has `M` dimensions.\n * @param blockShape A 1-D array. Must have shape `[M]`, all values must\n * be >= 1.\n * @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0.\n * `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input\n * dimension `i + 1`, which corresponds to spatial dimension `i`. It is required\n * that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]`\n *\n * This operation is equivalent to the following steps:\n *\n * 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ...,\n * blockShape[M-1], batch / prod(blockShape), x.shape[1], ...,\n * x.shape[N-1]]`\n *\n * 2. Permute dimensions of `reshaped`to produce `permuted` of shape `[batch /\n * prod(blockShape),x.shape[1], blockShape[0], ..., x.shape[M],\n * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`\n *\n * 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch /\n * prod(blockShape),x.shape[1] * blockShape[0], ..., x.shape[M] *\n * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`\n *\n * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted`\n * according to `crops` to produce the output of shape: `[batch /\n * prod(blockShape),x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1],\n * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] -\n * crops[M-1,1],x.shape[M+1], ..., x.shape[N-1]]`\n */\n/** @doc {heading: 'Tensors', subheading: 'Transformations'} */\nfunction batchToSpaceND_(\n x: T|TensorLike, blockShape: number[], crops: number[][]): T {\n const $x = convertToTensor(x, 'x', 'batchToSpaceND');\n const prod = blockShape.reduce((a, b) => a * b);\n\n util.assert(\n $x.rank >= 1 + blockShape.length,\n () => `input rank is ${$x.rank} but should be > than blockShape.length ${\n blockShape.length}`);\n\n util.assert(\n crops.length === blockShape.length,\n () => `crops.length is ${\n crops.length} but should be equal to blockShape.length ${\n blockShape.length}`);\n\n util.assert(\n $x.shape[0] % prod === 0,\n () => `input tensor batch is ${\n $x.shape[0]} but is not divisible by the product of ` +\n `the elements of blockShape ${blockShape.join(' * ')} === ${prod}`);\n\n const grad = (dy: T) => {\n return {$x: () => dy.spaceToBatchND(blockShape, crops)};\n };\n\n return ENGINE.runKernelFunc(\n backend => backend.batchToSpaceND($x, blockShape, crops), {$x}, grad);\n}\n\n/**\n * This operation divides \"spatial\" dimensions `[1, ..., M]` of the input into\n * a grid of blocks of shape `blockShape`, and interleaves these blocks with\n * the \"batch\" dimension (0) such that in the output, the spatial\n * dimensions `[1, ..., M]` correspond to the position within the grid,\n * and the batch dimension combines both the position within a spatial block\n * and the original batch position. Prior to division into blocks,\n * the spatial dimensions of the input are optionally zero padded\n * according to `paddings`. See below for a precise description.\n *\n * ```js\n * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);\n * const blockShape = [2, 2];\n * const paddings = [[0, 0], [0, 0]];\n *\n * x.spaceToBatchND(blockShape, paddings).print();\n * ```\n *\n * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +\n * remainingShape`, where spatialShape has `M` dimensions.\n * @param blockShape A 1-D array. Must have shape `[M]`, all values must\n * be >= 1.\n * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >=\n * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad\n * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It\n * is required that\n * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0`\n *\n * This operation is equivalent to the following steps:\n *\n * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input\n * according to `paddings` to produce `padded` of shape paddedShape.\n *\n * 2. Reshape `padded` to `reshapedPadded` of shape:\n * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ...,\n * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape`\n *\n * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded`\n * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ...,\n * paddedShape[M] / blockShape[M-1]] + remainingShape`\n *\n * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the\n * batch dimension, producing an output tensor of shape:\n * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ...,\n * paddedShape[M] / blockShape[M-1]] + remainingShape`\n */\n/** @doc {heading: 'Tensors', subheading: 'Transformations'} */\nfunction spaceToBatchND_(\n x: T|TensorLike, blockShape: number[], paddings: number[][]): T {\n const $x = convertToTensor(x, 'x', 'spaceToBatchND');\n\n util.assert(\n $x.rank >= 1 + blockShape.length,\n () => `input rank ${$x.rank} should be > than [blockShape] ${\n blockShape.length}`);\n\n util.assert(\n paddings.length === blockShape.length,\n () => `paddings.shape[0] ${\n paddings.length} must be equal to [blockShape] ${blockShape.length}`);\n\n util.assert(\n $x.shape.reduce(\n (a, b, i) => {\n if (i > 0 && i <= blockShape.length) {\n return a &&\n ((b + paddings[i - 1][0] + paddings[i - 1][1]) %\n blockShape[i - 1] ===\n 0);\n }\n return a;\n },\n true),\n () => `input spatial dimensions ${$x.shape.slice(1)} with paddings ${\n paddings.toString()} must be divisible by blockShapes ${\n blockShape.toString()}`);\n\n const grad = (dy: T) => {\n return {$x: () => dy.batchToSpaceND(blockShape, paddings)};\n };\n\n return ENGINE.runKernelFunc(\n backend => backend.spaceToBatchND($x, blockShape, paddings), {$x}, grad);\n}\n\n/**\n * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s.\n *\n * ```js\n * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n *\n * tf.unstack(a).forEach(tensor => tensor.print());\n * ```\n *\n * @param x A tensor object.\n * @param axis The axis to unstack along. Defaults to 0 (the first dim).\n */\n/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */\nfunction unstack_(x: Tensor|TensorLike, axis = 0): Tensor[] {\n axis = axis || 0;\n const $x = convertToTensor(x, 'x', 'unstack');\n util.assert(\n axis >= -$x.shape.length && axis < $x.shape.length,\n () =>\n `Axis = ${axis} is not in [-${$x.shape.length}, ${$x.shape.length})`);\n if (axis < 0) {\n axis += $x.shape.length;\n }\n const grad = (dy: Tensor[]) => {\n return {x: () => stack(dy, axis)};\n };\n const attrs = {axis};\n return ENGINE.runKernelFunc(\n backend => backend.unstack($x, axis), {x: $x}, grad, 'Unpack', attrs);\n}\n\n/**\n * Computes the cumulative sum of a `tf.Tensor` along `axis`.\n *\n * ```js\n * const x = tf.tensor([1, 2, 3, 4]);\n * x.cumsum().print();\n * ```\n * ```js\n * const x = tf.tensor([[1, 2], [3, 4]]);\n * x.cumsum().print();\n * ```\n *\n * @param x The input tensor to be summed.\n * @param axis The axis along which to sum. Optional. Defaults to 0.\n * @param exclusive Whether to perform exclusive cumulative sum. Optional.\n * Defaults to false. If set to true then the sum of each tensor entry\n * does not include its own value, but only the values previous to it\n * along the specified axis.\n * @param reverse Whether to sum in the opposite direction. Optional.\n * Defaults to false.\n */\n/** @doc {heading: 'Operations', subheading: 'Scan'} */\nfunction cumsum_(\n x: Tensor|TensorLike, axis = 0, exclusive = false, reverse = false): T {\n const $x = convertToTensor(x, 'x', 'cumsum');\n\n axis = axis | 0;\n const permutation = getAxesPermutation([axis], $x.rank);\n let permutedX = $x;\n if (permutation != null) {\n permutedX = $x.transpose(permutation);\n }\n const permutedAxis = getInnerMostAxes(1, $x.rank)[0];\n\n const grad = (dy: T) => {\n return {permutedX: () => dy.cumsum(axis, exclusive, !reverse)};\n };\n let value = ENGINE.runKernelFunc(\n backend => backend.cumsum(\n permutedX, permutedAxis, exclusive, reverse),\n {permutedX}, grad) as T;\n\n if (permutation != null) {\n value = value.transpose(permutation);\n }\n return value;\n}\n\n/**\n * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension\n * into the tensor's shape.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3, 4]);\n * const axis = 1;\n * x.expandDims(axis).print();\n * ```\n *\n * @param x The input tensor whose dimensions to be expanded.\n * @param axis The dimension index at which to insert shape of `1`. Defaults\n * to 0 (the first dimension).\n */\n/** @doc {heading: 'Tensors', subheading: 'Transformations'} */\nfunction expandDims_(\n x: Tensor|TensorLike, axis = 0): Tensor {\n const parseAs: DataType = null;\n const $x = convertToTensor(x, 'x', 'expandDims', parseAs);\n\n util.assert(axis <= $x.rank, () => 'Axis must be <= rank of the tensor');\n const newShape = $x.shape.slice();\n if (axis < 0) {\n // Negative value is counted from the tail of rank.\n util.assert(\n -($x.rank + 1) <= axis,\n () => `Axis must be in the interval [${- ($x.rank + 1)}, ${$x.rank}]`);\n axis = $x.rank + axis + 1;\n }\n newShape.splice(axis, 0, 1);\n return reshape($x, newShape as ShapeMap[R2]);\n}\n\n/**\n * Rearranges data from depth into blocks of spatial data. More specifically,\n * this op outputs a copy of the input tensor where values from the `depth`\n * dimension are moved in spatial blocks to the `height` and `width` dimensions.\n * The attr `blockSize` indicates the input block size and how the data is\n * moved.\n *\n * - Chunks of data of size `blockSize * blockSize` from depth are rearranged\n * into non-overlapping blocks of size `blockSize x blockSize`\n *\n * - The width the output tensor is `inputWidth * blockSize`, whereas the\n * height is `inputHeight * blockSize`\n *\n * - The Y, X coordinates within each block of the output image are determined\n * by the high order component of the input channel index\n *\n * - The depth of the input tensor must be divisible by `blockSize *\n * blockSize`\n *\n * The `dataFormat` attr specifies the layout of the input and output tensors\n * with the following options: \"NHWC\": [ `batch, height, width, channels` ]\n * \"NCHW\": [ `batch, channels, height, width` ]\n *\n * ```js\n * const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]);\n * const blockSize = 2;\n * const dataFormat = \"NHWC\";\n *\n * tf.depthToSpace(x, blockSize, dataFormat).print();\n * ```\n *\n * @param x The input tensor of rank 4\n * @param blockSIze An `int` that is `>= 2`. The size of the spatial block\n * @param dataFormat An optional string from: \"NHWC\", \"NCHW\". Defaults to \"NHWC\"\n */\n/** @doc {heading: 'Tensors', subheading: 'Transformations'} */\nfunction depthToSpace_(\n x: Tensor4D|TensorLike4D, blockSize: number,\n dataFormat: 'NHWC'|'NCHW' = 'NHWC'): Tensor4D {\n const $x = convertToTensor(x, 'x', 'depthToSpace') as Tensor4D;\n\n const inputHeight = (dataFormat === 'NHWC') ? $x.shape[1] : $x.shape[2];\n const inputWidth = (dataFormat === 'NHWC') ? $x.shape[2] : $x.shape[3];\n const inputDepth = (dataFormat === 'NHWC') ? $x.shape[3] : $x.shape[1];\n\n util.assert(\n inputHeight * blockSize >= 0,\n () => `Negative dimension size caused by overflow when multiplying\n ${inputHeight} and ${blockSize} for depthToSpace with input shape\n ${$x.shape}`);\n\n util.assert(\n inputWidth * blockSize >= 0,\n () => `Negative dimension size caused by overflow when multiplying\n ${inputWidth} and ${blockSize} for depthToSpace with input shape\n ${$x.shape}`);\n\n util.assert(\n (inputDepth % (blockSize * blockSize) === 0),\n () => `Dimension size must be evenly divisible by ${\n blockSize * blockSize} but is ${\n inputDepth} for depthToSpace with input shape ${$x.shape}`);\n\n return ENGINE.runKernelFunc(\n backend => backend.depthToSpace($x, blockSize, dataFormat), {$x});\n}\n\n/**\n * Computes the difference between two lists of numbers.\n *\n * Given a Tensor `x` and a Tensor `y`, this operation returns a Tensor `out`\n * that represents all values that are in `x` but not in `y`. The returned\n * Tensor `out` is sorted in the same order that the numbers appear in `x`\n * (duplicates are preserved). This operation also returns a Tensor indices that\n * represents the position of each out element in `x`. In other words:\n *\n * `out[i] = x[idx[i]] for i in [0, 1, ..., out.length - 1]`\n *\n * ```js\n * const x = [1, 2, 3, 4, 5, 6];\n * const y = [1, 3, 5];\n *\n * const [out, indices] = await tf.setdiff1dAsync(x, y);\n * out.print(); // [2, 4, 6]\n * indices.print(); // [1, 3, 5]\n * ```\n *\n * @param x 1-D Tensor. Values to keep.\n * @param y 1-D Tensor. Must have the same type as x. Values to exclude in the\n * output.\n * @returns Promise of Tensor tuple [out, indices].\n * out: Tensor with the same type as x.\n * indices: A Tensor of type int32.\n */\n/** @doc {heading: 'Tensors', subheading: 'Transformations'} */\nasync function setdiff1dAsync_(\n x: Tensor|TensorLike, y: Tensor|TensorLike): Promise<[Tensor, Tensor]> {\n const $x = convertToTensor(x, 'x', 'setdiff1d');\n const $y = convertToTensor(y, 'y', 'setdiff1d');\n\n util.assert(\n $x.dtype === $y.dtype,\n () => `x and y should have the same dtype, but got x (${\n $x.dtype}) and y (${$y.dtype}).`);\n\n util.assert(\n $x.rank === 1, () => `x should be 1D tensor, but got x (${$x.shape}).`);\n\n util.assert(\n $y.rank === 1, () => `y should be 1D tensor, but got y (${$y.shape}).`);\n\n const xVals = await $x.data();\n const yVals = await $y.data();\n const ySet = new Set(yVals);\n\n let outputSize = 0;\n for (let i = 0; i < xVals.length; i++) {\n if (!ySet.has(xVals[i])) {\n outputSize++;\n }\n }\n\n const buffer = new TensorBuffer([outputSize], $x.dtype);\n const indices = new TensorBuffer([outputSize], 'int32');\n for (let i = 0, p = 0; i < xVals.length; i++) {\n if (!ySet.has(xVals[i])) {\n buffer.values[p] = xVals[i];\n indices.values[p] = i;\n p++;\n }\n }\n return [buffer.toTensor(), indices.toTensor()];\n}\n\n/**\n * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`.\n *\n * The values are stored in CPU as `TypedArray`. Fill the buffer using\n * `buffer.set()`, or by modifying directly `buffer.values`.\n *\n * When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with\n * those values.\n *\n * ```js\n * // Create a buffer and set values at particular indices.\n * const buffer = tf.buffer([2, 2]);\n * buffer.set(3, 0, 0);\n * buffer.set(5, 1, 0);\n *\n * // Convert the buffer back to a tensor.\n * buffer.toTensor().print();\n * ```\n *\n * @param shape An array of integers defining the output tensor shape.\n * @param dtype The dtype of the buffer. Defaults to 'float32'.\n * @param values The values of the buffer as `TypedArray`. Defaults to\n * zeros.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nexport function buffer(\n shape: ShapeMap[R], dtype: D = 'float32' as D,\n values?: DataTypeMap[D]): TensorBuffer {\n dtype = dtype || 'float32' as D;\n util.assertNonNegativeIntegerDimensions(shape);\n return new TensorBuffer(shape, dtype, values);\n}\n\n/**\n * Prints information about the `tf.Tensor` including its data.\n *\n * ```js\n * const verbose = true;\n * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose);\n * ```\n * @param x The tensor to be printed.\n * @param verbose Whether to print verbose information about the ` Tensor`,\n * including dtype and size.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction print(x: T, verbose = false): void {\n console.log(x.toString(verbose));\n}\n\nexport {\n print // Not wrapped in op() since no need to increase stack trace.\n};\n\nexport const batchToSpaceND = op({batchToSpaceND_});\nexport const cast = op({cast_});\nexport const cumsum = op({cumsum_});\nexport const depthToSpace = op({depthToSpace_});\nexport const expandDims = op({expandDims_});\nexport const reshape = op({reshape_});\nexport const spaceToBatchND = op({spaceToBatchND_});\nexport const squeeze = op({squeeze_});\nexport const stack = op({stack_});\nexport const unstack = op({unstack_});\nexport const setdiff1dAsync = setdiff1dAsync_;\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * Gets the new shape of the input Tensor after it's been reshaped\n * to:\n * [blockShape[0], ..., blockShape[M-1], batch / prod(blockShape),\n * inputShape[1], ..., inputShape[N-1]]\n *\n * See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd\n */\nexport function getReshaped(\n inputShape: number[], blockShape: number[], prod: number,\n batchToSpace = true): number[] {\n let reshaped: number[] = [];\n if (batchToSpace) {\n reshaped = reshaped.concat(blockShape.slice(0));\n reshaped.push(inputShape[0] / prod);\n reshaped = reshaped.concat(inputShape.slice(1));\n } else {\n reshaped = reshaped.concat(inputShape[0]);\n const spatialLength = blockShape.length;\n for (let i = 0; i < spatialLength; ++i) {\n reshaped =\n reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]);\n }\n reshaped = reshaped.concat(inputShape.slice(spatialLength + 1));\n }\n return reshaped;\n}\n\n/**\n * Gets the permutation that will transpose the dimensions of the\n * reshaped tensor to shape:\n *\n * [batch / prod(block_shape),inputShape[1], blockShape[0], ...,\n * inputShape[M], blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]\n *\n * see step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd\n */\nexport function getPermuted(\n reshapedRank: number, blockShapeRank: number,\n batchToSpace = true): number[] {\n const permuted = [];\n if (batchToSpace) {\n permuted.push(blockShapeRank);\n for (let i = blockShapeRank + 1; i < reshapedRank; ++i) {\n if (i <= 2 * blockShapeRank) {\n permuted.push(i);\n permuted.push(i - (blockShapeRank + 1));\n } else {\n permuted.push(i);\n }\n }\n } else {\n const permutedBeforeBatch = [];\n const permutedAfterBatch = [];\n for (let i = 1; i < reshapedRank; ++i) {\n if (i >= blockShapeRank * 2 + 1 || i % 2 === 1) {\n permutedAfterBatch.push(i);\n } else {\n permutedBeforeBatch.push(i);\n }\n }\n permuted.push(...permutedBeforeBatch);\n permuted.push(0);\n permuted.push(...permutedAfterBatch);\n }\n return permuted;\n}\n\n/**\n * Gets the shape of the reshaped and permuted input Tensor before any cropping\n * is applied. The new shape will be:\n *\n * [batch / prod(blockShape),inputShape[1] * blockShape[0], ...,\n * inputShape[M] * blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]\n *\n * See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd\n */\nexport function getReshapedPermuted(\n inputShape: number[], blockShape: number[], prod: number,\n batchToSpace = true): number[] {\n const reshapedPermuted = [];\n\n if (batchToSpace) {\n reshapedPermuted.push(inputShape[0] / prod);\n } else {\n reshapedPermuted.push(inputShape[0] * prod);\n }\n\n for (let i = 1; i < inputShape.length; ++i) {\n if (i <= blockShape.length) {\n if (batchToSpace) {\n reshapedPermuted.push(blockShape[i - 1] * inputShape[i]);\n } else {\n reshapedPermuted.push(inputShape[i] / blockShape[i - 1]);\n }\n } else {\n reshapedPermuted.push(inputShape[i]);\n }\n }\n\n return reshapedPermuted;\n}\n\n/**\n * Converts the crops argument into the beginning coordinates of a slice\n * operation.\n */\nexport function getSliceBeginCoords(\n crops: number[][], blockShape: number): number[] {\n const sliceBeginCoords = [0];\n for (let i = 0; i < blockShape; ++i) {\n sliceBeginCoords.push(crops[i][0]);\n }\n return sliceBeginCoords;\n}\n\n/**\n * Converts the crops argument into the size of a slice operation. When\n * combined with getSliceBeginCoords this function allows the reshaped and\n * permuted Tensor to be cropped to its final output shape of:\n *\n * inputShape[1] * blockShape[0] - crops[0,0] - crops[0,1], ...,\n * inputShape[M] * blockShape[M-1] -crops[M-1,0] -\n * crops[M-1,1],inputShape[M+1], ..., inputShape[N-1]]\n *\n * See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd\n */\nexport function getSliceSize(\n uncroppedShape: number[], crops: number[][], blockShape: number): number[] {\n const sliceSize = uncroppedShape.slice(0, 1);\n for (let i = 0; i < blockShape; ++i) {\n sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]);\n }\n\n return sliceSize;\n}\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n// Allow UpperCamelCase variable names\n// tslint:disable: variable-name\n// Unfortunately just enabling PascalCase per file (tslint:enable:\n// allow-pascal-case) doesn't work.\nimport {NamedTensorInfoMap, TensorInfo} from './kernel_registry';\nimport {PixelData} from './types';\n\nexport const Add = 'Add';\nexport type AddInputs = BinaryInputs;\n\nexport const AddN = 'AddN';\nexport type AddNInputs = TensorInfo[];\n\nexport type BinaryInputs = Pick;\n\nexport const Div = 'Div';\nexport type DivInputs = BinaryInputs;\n\nexport const FusedBatchNorm = 'FusedBatchNorm';\nexport type FusedBatchNormInputs =\n Pick;\nexport interface FusedBatchNormAttrs {\n varianceEpsilon: number;\n}\n\nexport const SquaredDifference = 'SquaredDifference';\nexport type SquaredDifferenceInputs = BinaryInputs;\n\nexport const Square = 'Square';\nexport type SquareInputs = Pick;\n\nexport const Transpose = 'Transpose';\nexport type TransposeInputs = Pick;\nexport interface TransposeAttrs {\n perm: number[];\n}\n\nexport const NonMaxSuppressionV5 = 'NonMaxSuppressionV5';\nexport type NonMaxSuppressionV5Inputs =\n Pick;\nexport interface NonMaxSuppressionV5Attrs {\n maxOutputSize: number;\n iouThreshold: number;\n scoreThreshold: number;\n softNmsSigma: number;\n}\n\nexport const BroadcastTo = 'BroadcastTo';\nexport type BroadcastToInputs = Pick;\nexport interface BroadCastToAttrs {\n shape: number[];\n inputShape: number[]; // for gradient\n}\n\nexport const OneHot = 'OneHot';\nexport type OneHotInputs = Pick;\nexport interface OneHotAttrs {\n depth: number;\n onValue: number;\n offValue: number;\n}\n\nexport const Identity = 'Identity';\nexport type IdentityInputs = Pick;\n\nexport const Tile = 'Tile';\nexport type TileInputs = Pick;\nexport interface TileAttrs {\n reps: number[];\n}\n\nexport const PadV2 = 'PadV2';\nexport type PadV2Inputs = Pick;\nexport interface PadV2Attrs {\n paddings: Array<[number, number]>;\n constantValue: number;\n}\n\n/**\n * TensorFlow.js-only kernels\n */\nexport const FromPixels = 'FromPixels';\nexport interface FromPixelsInputs {\n pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|\n HTMLVideoElement;\n}\nexport interface FromPixelsAttrs {\n numChannels: number;\n}\n\nexport const MaxPoolWithArgmax = 'MaxPoolWithArgmax';\nexport type MaxPoolWithArgmaxInputs = Pick;\nexport interface MaxPoolWithArgmaxAttrs {\n filterSize: [number, number]|number;\n strides: [number, number]|number;\n pad: 'valid'|'same'|number;\n includeBatchInIndex: boolean;\n}\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {ENGINE, ForwardFunc} from '../engine';\nimport {Add, AddInputs} from '../kernel_names';\nimport {Tensor} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {makeTypesMatch} from '../tensor_util';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\n\nimport {op} from './operation';\n\n/**\n * Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting.\n *\n * We also expose `tf.addStrict` which has the same signature as this op and\n * asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 2, 3, 4]);\n * const b = tf.tensor1d([10, 20, 30, 40]);\n *\n * a.add(b).print(); // or tf.add(a, b)\n * ```\n *\n * ```js\n * // Broadcast add a with b.\n * const a = tf.scalar(5);\n * const b = tf.tensor1d([10, 20, 30, 40]);\n *\n * a.add(b).print(); // or tf.add(a, b)\n * ```\n * @param a The first `tf.Tensor` to add.\n * @param b The second `tf.Tensor` to add. Must have the same type as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */\nfunction add_(a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'add');\n let $b = convertToTensor(b, 'b', 'add');\n [$a, $b] = makeTypesMatch($a, $b);\n\n const forward: ForwardFunc = (backend, save) => {\n const res = backend.add($a, $b);\n save([$a, $b]);\n return res;\n };\n\n const inputs: AddInputs = {a: $a, b: $b};\n\n return ENGINE.runKernelFunc(\n forward, inputs as {} as NamedTensorMap, null /* gradient */,\n Add) as T;\n}\n\nexport const add = op({add_});\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * Returns the dimensions in the input shape that are broadcasted to\n * produce the provided output shape.\n *\n * The returned dimensions are 0-indexed and sorted. An example:\n * inShape = [4, 1, 3]\n * outShape = [5, 4, 3, 3]\n * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3.\n */\nexport function getBroadcastDims(\n inShape: number[], outShape: number[]): number[] {\n const inRank = inShape.length;\n const dims: number[] = [];\n for (let i = 0; i < inRank; i++) {\n const dim = inRank - 1 - i;\n const a = inShape[dim] || 1;\n const b = outShape[outShape.length - 1 - i] || 1;\n if (b > 1 && a === 1) {\n dims.unshift(dim);\n }\n }\n return dims;\n}\n\n/**\n * Returns the axes in the output space that should be reduced to produce\n * the input space.\n */\nexport function getReductionAxes(\n inShape: number[], outShape: number[]): number[] {\n const result: number[] = [];\n for (let i = 0; i < outShape.length; i++) {\n const inDim = inShape[inShape.length - i - 1];\n const outAxis = outShape.length - i - 1;\n const outDim = outShape[outAxis];\n if (inDim == null || (inDim === 1 && outDim > 1)) {\n result.unshift(outAxis);\n }\n }\n return result;\n}\n\nexport function assertAndGetBroadcastShape(\n shapeA: number[], shapeB: number[]): number[] {\n const result: number[] = [];\n const l = Math.max(shapeA.length, shapeB.length);\n\n for (let i = 0; i < l; i++) {\n let a = shapeA[shapeA.length - i - 1];\n if (a == null) {\n a = 1;\n }\n let b = shapeB[shapeB.length - i - 1];\n if (b == null) {\n b = 1;\n }\n if (a === 1) {\n result.unshift(b);\n } else if (b === 1) {\n result.unshift(a);\n } else if (a !== b) {\n const errMsg = `Operands could not be broadcast together with shapes ` +\n `${shapeA} and ${shapeB}.`;\n throw Error(errMsg);\n } else {\n result.unshift(a);\n }\n }\n return result;\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {op} from './operation';\nimport {scalar, zerosLike} from './tensor_ops';\n\n/**\n * Computes `-1 * x` element-wise.\n *\n * ```js\n * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]);\n *\n * x.neg().print(); // or tf.neg(x)\n * ```\n *\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction neg_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'neg');\n\n const grad = (dy: T) => {\n return {x: () => dy.neg()};\n };\n\n const attrs = {};\n const inputsToSave = [$x];\n return ENGINE.runKernelFunc(\n backend => backend.neg($x), {x: $x}, grad, 'Neg', attrs, inputsToSave);\n}\n\n/**\n * Computes ceiling of input `tf.Tensor` element-wise: `ceil(x)`\n *\n * ```js\n * const x = tf.tensor1d([.6, 1.1, -3.3]);\n *\n * x.ceil().print(); // or tf.ceil(x)\n * ```\n * @param x The input Tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction ceil_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'ceil');\n\n // TODO(manrajgrover): Return null for gradients when backprop supports it.\n const grad = (dy: T) => {\n return {$x: () => zerosLike(dy)};\n };\n return ENGINE.runKernelFunc(backend => backend.ceil($x), {$x}, grad);\n}\n\n/**\n * Computes floor of input `tf.Tensor` element-wise: `floor(x)`.\n *\n * ```js\n * const x = tf.tensor1d([.6, 1.1, -3.3]);\n *\n * x.floor().print(); // or tf.floor(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction floor_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'floor');\n\n // TODO(nsthorat): Let gradients be null for cases where we want to stop\n // backpropgation.\n const grad = (dy: T) => {\n return {$x: () => zerosLike(dy)};\n };\n return ENGINE.runKernelFunc(backend => backend.floor($x), {$x}, grad);\n}\n\n/**\n * Returns an element-wise indication of the sign of a number.\n *\n * ```js\n * const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]);\n *\n * x.sign().print(); // or tf.sign(x)\n * ```\n * @param x The input Tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction sign_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'sign');\n\n const grad = (dy: T) => {\n return {$x: () => zerosLike(dy)};\n };\n return ENGINE.runKernelFunc(backend => backend.sign($x), {$x}, grad);\n}\n\n/**\n * RReturns which elements of x are NaN.\n *\n * ```js\n * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);\n *\n * x.isNaN().print(); // or tf.isNaN(x)\n * ```\n * @param x The input Tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction isNaN_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'isNaN');\n\n // TODO(nsthorat): Let gradients be null for cases where we want to stop\n // backpropgation.\n const grad = (dy: T) => {\n return {$x: () => zerosLike(dy)};\n };\n return ENGINE.runKernelFunc(backend => backend.isNaN($x), {$x}, grad);\n}\n\n/**\n * Returns which elements of x are Infinity or -Infinity.\n *\n * ```js\n * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);\n *\n * x.isInf().print(); // or tf.isNaN(x)\n * ```\n * @param x The input Tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction isInf_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'isInf');\n\n // TODO(nsthorat): Let gradients be null for cases where we want to stop\n // backpropgation.\n const grad = (dy: T) => {\n return {$x: () => zerosLike(dy)};\n };\n return ENGINE.runKernelFunc(backend => backend.isInf($x), {$x}, grad);\n}\n\n/**\n * Returns which elements of x are finite.\n *\n * ```js\n * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);\n *\n * x.isFinite().print(); // or tf.isNaN(x)\n * ```\n * @param x The input Tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction isFinite_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'isFinite');\n\n // TODO(nsthorat): Let gradients be null for cases where we want to stop\n // backpropgation.\n const grad = (dy: T) => {\n return {$x: () => zerosLike(dy)};\n };\n return ENGINE.runKernelFunc(backend => backend.isFinite($x), {$x}, grad);\n}\n\n/**\n * Computes round of input `tf.Tensor` element-wise: `round(x)`.\n * It implements banker's rounding.\n *\n * ```js\n * const x = tf.tensor1d([.6, 1.1, -3.3]);\n *\n * x.round().print(); // or tf.round(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction round_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'round');\n\n // TODO(nsthorat): Let gradients be null for cases where we want to stop\n // backpropgation.\n const grad = (dy: T) => {\n return {$x: () => zerosLike(dy)};\n };\n return ENGINE.runKernelFunc(backend => backend.round($x), {$x}, grad);\n}\n\n/**\n * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x`\n *\n * ```js\n * const x = tf.tensor1d([1, 2, -3]);\n *\n * x.exp().print(); // or tf.exp(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction exp_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'exp');\n\n const bck = (dy: T, saved: Tensor[]) => {\n return {x: () => dy.mulStrict(saved[0] as T)};\n };\n const attrs = {};\n const inputsToSave: Tensor[] = [];\n const outputsToSave = [true];\n return ENGINE.runKernelFunc((backend, save) => {\n const y = backend.exp($x);\n save([y]);\n return y;\n }, {x: $x}, bck, 'Exp', attrs, inputsToSave, outputsToSave);\n}\n\n/**\n * Computes exponential of the input `tf.Tensor` minus one element-wise.\n * `e ^ x - 1`\n *\n * ```js\n * const x = tf.tensor1d([1, 2, -3]);\n *\n * x.expm1().print(); // or tf.expm1(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction expm1_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'expm1');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => dy.mul($x.exp())} as {$x: () => T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.expm1($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)`\n *\n * ```js\n * const x = tf.tensor1d([1, 2, Math.E]);\n *\n * x.log().print(); // or tf.log(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction log_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'log');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {x: () => dy.div($x.toFloat())} as {x: () => T};\n };\n\n const attrs = {};\n const inputsToSave = [$x];\n\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.log($x);\n save([$x]);\n return res;\n }, {x: $x}, grad, 'Log', attrs, inputsToSave);\n}\n\n/**\n * Computes natural logarithm of the input `tf.Tensor` plus one\n * element-wise: `ln(1 + x)`\n *\n * ```js\n * const x = tf.tensor1d([1, 2, Math.E - 1]);\n *\n * x.log1p().print(); // or tf.log1p(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction log1p_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'log1p');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => dy.div($x.add(1))} as {$x: () => T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.log1p($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)`\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 4, -1]);\n *\n * x.sqrt().print(); // or tf.sqrt(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction sqrt_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'sqrt');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => dy.div($x.toFloat().sqrt().mul(2))} as {$x: () => T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.sqrt($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes reciprocal of square root of the input `tf.Tensor` element-wise:\n * `y = 1 / sqrt(x)`\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 4, -1]);\n *\n * x.rsqrt().print(); // or tf.rsqrt(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction rsqrt_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'rsqrt');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {x: () => dy.div($x.pow(1.5).mul(2)).neg() as T};\n };\n const inputsToSave = [$x];\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.rsqrt($x);\n save([$x]);\n return res;\n }, {x: $x}, grad, 'Rsqrt', {} /* attrs */, inputsToSave);\n}\n\n/**\n * Computes reciprocal of x element-wise: `1 / x`\n *\n * ```js\n * const x = tf.tensor1d([0, 1, 2]);\n *\n * x.reciprocal().print(); // or tf.reciprocal(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction reciprocal_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'reciprocal');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => dy.div($x.square().neg())} as {$x: () => T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.reciprocal($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes absolute value element-wise: `abs(x)`\n *\n * ```js\n * const x = tf.tensor1d([-1, 2, -3, 4]);\n *\n * x.abs().print(); // or tf.abs(x)\n * ```\n * @param x The input `tf.Tensor`.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction abs_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'abs');\n\n if ($x.dtype === 'complex64') {\n return ENGINE.runKernelFunc(backend => backend.complexAbs($x), {$x});\n }\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {x: () => dy.mul($x.toFloat().step(-1))} as {x: () => T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.abs($x);\n save([$x]);\n return res;\n }, {x: $x}, grad, 'Abs');\n}\n\n/**\n * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)`\n *\n * ```js\n * const x = tf.tensor1d([-1, 2, -3, 4]);\n *\n * x.clipByValue(-2, 3).print(); // or tf.clipByValue(x, -2, 3)\n * ```\n * @param x The input tensor.\n * @param clipValueMin Lower-bound of range to be clipped to.\n * @param clipValueMax Upper-bound of range to be clipped to.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction clipByValue_(\n x: T|TensorLike, clipValueMin: number, clipValueMax: number): T {\n const $x = convertToTensor(x, 'x', 'clipByValue');\n util.assert(\n (clipValueMin <= clipValueMax),\n () => `Error in clip: min (${clipValueMin}) must be ` +\n `less than or equal to max (${clipValueMax}).`);\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {\n x: () => dy.where(\n $x.greaterEqual(clipValueMin)\n .logicalAnd($x.lessEqual(clipValueMax)),\n zerosLike(dy)) as T,\n };\n };\n const inputsToSave = [$x];\n const attr = {min: clipValueMin, max: clipValueMax};\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.clip($x, clipValueMin, clipValueMax);\n save([$x]);\n return res;\n }, {x: $x}, grad, 'ClipByValue', attr, inputsToSave);\n}\n\n/**\n * Computes sigmoid element-wise, `1 / (1 + exp(-x))`\n *\n * ```js\n * const x = tf.tensor1d([0, -1, 2, -3]);\n *\n * x.sigmoid().print(); // or tf.sigmoid(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction sigmoid_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'sigmoid');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [y] = saved;\n return {x: () => dy.mul(y.mul(scalar(1).sub(y)))} as {x: () => T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const y = backend.sigmoid($x);\n save([y]);\n return y;\n }, {x: $x}, grad, 'Sigmoid');\n}\n\n/**\n * Computes log sigmoid of the input `tf.Tensor` element-wise:\n * `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`.\n *\n * ```js\n * const x = tf.tensor1d([0, 1, -1, .7]);\n *\n * x.logSigmoid().print(); // or tf.logSigmoid(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction logSigmoid_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'logSigmoid');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => dy.mul($x.neg().sigmoid())} as {$x: () => T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.softplus($x.neg()).neg();\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)`\n *\n * ```js\n * const x = tf.tensor1d([0, 1, -1, .7]);\n *\n * x.softplus().print(); // or tf.softplus(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction softplus_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'softplus');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => dy.mul($x.sigmoid())} as {$x: () => T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.softplus($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes sin of the input Tensor element-wise: `sin(x)`\n *\n * ```js\n * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);\n *\n * x.sin().print(); // or tf.sin(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction sin_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'sin');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {x: () => $x.toFloat().cos().mul(dy)} as {x: () => T};\n };\n const inputsToSave = [$x];\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.sin($x);\n save([$x]);\n return res;\n }, {x: $x}, grad, 'Sin', {} /* attrs */, inputsToSave);\n}\n\n/**\n * Computes cos of the input `tf.Tensor` element-wise: `cos(x)`\n *\n * ```js\n * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);\n *\n * x.cos().print(); // or tf.cos(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction cos_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'cos');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {x: () => $x.toFloat().sin().neg().mul(dy)} as {x: () => T};\n };\n const inputsToSave = [$x];\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.cos($x);\n save([$x]);\n return res;\n }, {x: $x}, grad, 'Cos', {} /* attrs */, inputsToSave);\n}\n\n/**\n * Computes tan of the input `tf.Tensor` element-wise, `tan(x)`\n *\n * ```js\n * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);\n *\n * x.tan().print(); // or tf.tan(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction tan_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'tan');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => dy.div($x.cos().square())} as {$x: () => T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.tan($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes asin of the input `tf.Tensor` element-wise: `asin(x)`\n *\n * ```js\n * const x = tf.tensor1d([0, 1, -1, .7]);\n *\n * x.asin().print(); // or tf.asin(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction asin_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'asin');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {\n $x: () => dy.divStrict(scalar(1).sub($x.toFloat().square()).sqrt() as T)\n };\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.asin($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes acos of the input `tf.Tensor` element-wise: `acos(x)`\n *\n * ```js\n * const x = tf.tensor1d([0, 1, -1, .7]);\n *\n * x.acos().print(); // or tf.acos(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction acos_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'acos');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {\n $x: () =>\n dy.divStrict(scalar(1).sub($x.toFloat().square()).sqrt() as T).neg()\n };\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.acos($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes atan of the input `tf.Tensor` element-wise: `atan(x)`\n *\n * ```js\n * const x = tf.tensor1d([0, 1, -1, .7]);\n *\n * x.atan().print(); // or tf.atan(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction atan_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'atan');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => dy.div($x.toFloat().square().add(1))} as {$x: () => T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.atan($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)`\n *\n * ```js\n * const x = tf.tensor1d([0, 1, -1, .7]);\n *\n * x.sinh().print(); // or tf.sinh(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction sinh_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'sinh');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => $x.toFloat().cosh().mulStrict(dy) as T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.sinh($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)`\n *\n * ```js\n * const x = tf.tensor1d([0, 1, -1, .7]);\n *\n * x.cosh().print(); // or tf.cosh(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction cosh_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'cosh');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => $x.toFloat().sinh().mulStrict(dy) as T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.cosh($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes hyperbolic tangent of the input `tf.Tensor` element-wise: `tanh(x)`\n *\n * ```js\n * const x = tf.tensor1d([0, 1, -1, 70]);\n *\n * x.tanh().print(); // or tf.tanh(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction tanh_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'tanh');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [y] = saved;\n return {x: () => scalar(1).sub(y.square()).mulStrict(dy) as T};\n };\n const outputsToSave = [true];\n return ENGINE.runKernelFunc(\n (backend, save) => {\n const y = backend.tanh($x);\n save([y]);\n return y;\n },\n {x: $x}, grad, 'Tanh', {} /* attrs */, null /* inputsToSave */,\n outputsToSave);\n}\n\n/**\n * Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise:\n * `asinh(x)`\n *\n * ```js\n * const x = tf.tensor1d([0, 1, -1, .7]);\n *\n * x.asinh().print(); // or tf.asinh(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction asinh_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'asinh');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {\n $x: () => dy.divStrict(scalar(1).add($x.toFloat().square()).sqrt() as T)\n };\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.asinh($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise:\n * `acosh(x)`\n *\n * ```js\n * const x = tf.tensor1d([10, 1, 3, 5.7]);\n *\n * x.acosh().print(); // or tf.acosh(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction acosh_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'acosh');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => dy.divStrict($x.toFloat().square().sub(1).sqrt() as T)};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.acosh($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise:\n * `atanh(x)`\n *\n * ```js\n * const x = tf.tensor1d([0, .1, -.1, .7]);\n *\n * x.atanh().print(); // or tf.atanh(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction atanh_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'atanh');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => dy.div(scalar(1).sub($x.toFloat().square()))} as\n {$x: () => T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.atanh($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes gause error function of the input `tf.Tensor` element-wise:\n * `erf(x)`\n *\n * ```js\n * const x = tf.tensor1d([0, .1, -.1, .7]);\n *\n * x.erf().print(); // or tf.erf(x);\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction erf_(x: T|TensorLike): T {\n let $x = convertToTensor(x, 'x', 'erf');\n util.assert(\n $x.dtype === 'int32' || $x.dtype === 'float32',\n () => 'Input dtype must be `int32` or `float32`.');\n\n if ($x.dtype === 'int32') {\n $x = $x.toFloat();\n }\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {\n $x: () => dy.mul($x.square().neg().exp().mul(2 / Math.sqrt(Math.PI)))\n } as {$x: () => T};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.erf($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha * x`\n *\n * ```js\n * const x = tf.tensor1d([0, 2, -1, -3]);\n *\n * x.step(.5).print(); // or tf.step(x, .5)\n * ```\n * @param x The input tensor.\n * @param alpha The gradient when input is negative.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction step_(x: T|TensorLike, alpha = 0.0): T {\n const $x = convertToTensor(x, 'x', 'step');\n\n // TODO(manrajgrover): Return null for gradients when backprop supports\n // it.\n const grad = (dy: T) => {\n return {$x: () => zerosLike(dy)};\n };\n return ENGINE.runKernelFunc(backend => backend.step($x, alpha), {$x}, grad);\n}\n\nexport const abs = op({abs_});\nexport const acos = op({acos_});\nexport const acosh = op({acosh_});\nexport const asin = op({asin_});\nexport const asinh = op({asinh_});\nexport const atan = op({atan_});\nexport const atanh = op({atanh_});\nexport const ceil = op({ceil_});\nexport const clipByValue = op({clipByValue_});\nexport const cos = op({cos_});\nexport const cosh = op({cosh_});\nexport const erf = op({erf_});\nexport const exp = op({exp_});\nexport const expm1 = op({expm1_});\nexport const floor = op({floor_});\nexport const log = op({log_});\nexport const log1p = op({log1p_});\nexport const logSigmoid = op({logSigmoid_});\nexport const neg = op({neg_});\nexport const reciprocal = op({reciprocal_});\nexport const round = op({round_});\nexport const rsqrt = op({rsqrt_});\nexport const sigmoid = op({sigmoid_});\nexport const sign = op({sign_});\nexport const isNaN = op({isNaN_});\nexport const isInf = op({isInf_});\nexport const isFinite = op({isFinite_});\nexport const sin = op({sin_});\nexport const sinh = op({sinh_});\nexport const softplus = op({softplus_});\nexport const sqrt = op({sqrt_});\nexport const step = op({step_});\nexport const tan = op({tan_});\nexport const tanh = op({tanh_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor} from '../tensor';\nimport {makeTypesMatch} from '../tensor_util';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {add} from './add';\nimport * as broadcast_util from './broadcast_util';\nimport {op} from './operation';\nimport {scalar, zerosLike} from './tensor_ops';\nimport {neg} from './unary_ops';\n\n/**\n * Adds two `tf.Tensor`s element-wise, A + B.\n *\n * Inputs must be the same shape. For broadcasting support, use add() instead.\n *\n * @param a The first Tensor to add element-wise.\n * @param b The second Tensor to add element-wise.\n */\nfunction addStrict_(a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'addStrict');\n const $b = convertToTensor(b, 'b', 'addStrict');\n util.assertShapesMatch($a.shape, $b.shape, 'Error in addStrict: ');\n return $a.add($b);\n}\n\n/**\n * Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting.\n *\n * We also expose `tf.subStrict` which has the same signature as this op and\n * asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([10, 20, 30, 40]);\n * const b = tf.tensor1d([1, 2, 3, 4]);\n *\n * a.sub(b).print(); // or tf.sub(a, b)\n * ```\n *\n * ```js\n * // Broadcast subtract a with b.\n * const a = tf.tensor1d([10, 20, 30, 40]);\n * const b = tf.scalar(5);\n *\n * a.sub(b).print(); // or tf.sub(a, b)\n * ```\n * @param a The first `tf.Tensor` to subtract from.\n * @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as\n * `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */\nfunction sub_(a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'sub');\n let $b = convertToTensor(b, 'b', 'sub');\n [$a, $b] = makeTypesMatch($a, $b);\n\n const outShape =\n broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);\n\n const der = (dy: Tensor) => {\n const derA = () => {\n let res = dy;\n const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);\n if (reduceAxes.length > 0) {\n res = res.sum(reduceAxes);\n }\n return res.reshape($a.shape);\n };\n const derB = () => {\n let res = dy;\n const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);\n if (reduceAxes.length > 0) {\n res = res.sum(reduceAxes);\n }\n return res.neg().reshape($b.shape);\n };\n return {a: derA, b: derB};\n };\n return ENGINE.runKernelFunc(\n backend => backend.subtract($a, $b), {a: $a, b: $b}, der, 'Sub') as\n T;\n}\n\n/**\n * Subtracts two `tf.Tensor`s element-wise, A - B. Inputs must\n * be the same shape.\n *\n * For broadcasting support, use `tf.sub` instead.\n *\n * @param a The first Tensor to subtract element-wise.\n * @param b The second Tensor to subtract element-wise.\n */\nfunction subStrict_(a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'subStrict');\n const $b = convertToTensor(b, 'b', 'subStrict');\n util.assertShapesMatch($a.shape, $b.shape, 'Error in subStrict: ');\n return $a.sub($b);\n}\n\n/**\n * Computes the power of one `tf.Tensor` to another. Supports broadcasting.\n *\n * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for\n * corresponding elements in x and y. The result's dtype will be the upcasted\n * type of the `base` and `exp` dtypes.\n *\n * ```js\n * const a = tf.tensor([[2, 3], [4, 5]])\n * const b = tf.tensor([[1, 2], [3, 0]]).toInt();\n *\n * a.pow(b).print(); // or tf.pow(a, b)\n * ```\n *\n * ```js\n * const a = tf.tensor([[1, 2], [3, 4]])\n * const b = tf.tensor(2).toInt();\n *\n * a.pow(b).print(); // or tf.pow(a, b)\n * ```\n * We also expose `powStrict` which has the same signature as this op and\n * asserts that `base` and `exp` are the same shape (does not broadcast).\n *\n * @param base The base `tf.Tensor` to pow element-wise.\n * @param exp The exponent `tf.Tensor` to pow element-wise.\n */\n/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */\nfunction pow_(\n base: Tensor|TensorLike, exp: Tensor|TensorLike): T {\n let $base = convertToTensor(base, 'base', 'pow');\n let $exp = convertToTensor(exp, 'exp', 'pow');\n [$base, $exp] = makeTypesMatch($base, $exp);\n\n const outShape =\n broadcast_util.assertAndGetBroadcastShape($base.shape, $exp.shape);\n const grad = (dy: Tensor, saved: Tensor[]) => {\n const [$base, $exp, y] = saved;\n const derBase = () => {\n const expFloat = $exp.toFloat();\n let res = dy.mul(expFloat.mul($base.pow(expFloat.sub(scalar(1)))));\n const reduceAxes = broadcast_util.getReductionAxes($base.shape, outShape);\n if (reduceAxes.length > 0) {\n res = res.sum(reduceAxes);\n }\n return res.reshape($base.shape) as T;\n };\n const derExp = () => {\n const condition = $base.greater(0);\n const logBase = $base.log().where(condition, zerosLike($base));\n let res = dy.mul(y.mul(logBase));\n const reduceAxes = broadcast_util.getReductionAxes($exp.shape, outShape);\n if (reduceAxes.length > 0) {\n res = res.sum(reduceAxes);\n }\n return res.reshape($exp.shape);\n };\n return {a: derBase, b: derExp};\n };\n\n const attrs = {};\n const inputsToSave = [$base, $exp];\n const outputsToSave = [true];\n return ENGINE.runKernelFunc((backend, save) => {\n const y = backend.pow($base, $exp);\n save([$base, $exp, y]);\n return y;\n }, {a: $base, b: $exp}, grad, 'Pow', attrs, inputsToSave, outputsToSave) as T;\n}\n\n/**\n * Computes the power of one `tf.Tensor` to another. Inputs must\n * be the same shape.\n *\n * For broadcasting support, use `tf.pow` instead.\n *\n * @param base The base tensor to pow element-wise.\n * @param exp The exponent tensor to pow element-wise.\n */\nfunction powStrict_(base: T, exp: Tensor): T {\n util.assertShapesMatch(base.shape, exp.shape, 'Error in powStrict: ');\n return base.pow(exp);\n}\n\n/**\n * Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting.\n *\n * We also expose `tf.mulStrict` which has the same signature as this op and\n * asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 2, 3, 4]);\n * const b = tf.tensor1d([2, 3, 4, 5]);\n *\n * a.mul(b).print(); // or tf.mul(a, b)\n * ```\n *\n * ```js\n * // Broadcast mul a with b.\n * const a = tf.tensor1d([1, 2, 3, 4]);\n * const b = tf.scalar(5);\n *\n * a.mul(b).print(); // or tf.mul(a, b)\n * ```\n * @param a The first tensor to multiply.\n * @param b The second tensor to multiply. Must have the same dtype as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */\nfunction mul_(a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'mul');\n let $b = convertToTensor(b, 'b', 'mul');\n [$a, $b] = makeTypesMatch($a, $b);\n\n const outShape =\n broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);\n\n const der = (dy: Tensor, saved: Tensor[]) => {\n const [$a, $b] = saved;\n const derA = () => {\n const res = dy.mul($b.toFloat());\n const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);\n if (reduceAxes.length > 0) {\n return res.sum(reduceAxes).reshape($a.shape);\n }\n return res;\n };\n const derB = () => {\n const res = dy.mul($a.toFloat());\n const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);\n if (reduceAxes.length > 0) {\n return res.sum(reduceAxes).reshape($b.shape);\n }\n return res;\n };\n return {a: derA, b: derB};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.multiply($a, $b);\n save([$a, $b]);\n return res;\n }, {a: $a, b: $b}, der, 'Mul') as T;\n}\n\n/**\n * Multiplies two `tf.Tensor`s element-wise, A * B.\n *\n * Inputs must be the same shape. For broadcasting support, use `tf.mul`.\n *\n * @param a The first tensor to multiply.\n * @param b The first tensor to multiply. Must have the same\n * dtype as `a`.\n */\nfunction mulStrict_(a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'mul');\n const $b = convertToTensor(b, 'b', 'mul');\n util.assertShapesMatch($a.shape, $b.shape, 'Error in multiplyStrict: ');\n return $a.mul($b);\n}\n\n/**\n * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.\n * The result is rounded with floor function.\n *\n *\n * ```js\n * const a = tf.tensor1d([1, 4, 9, 16]);\n * const b = tf.tensor1d([1, 2, 3, 4]);\n *\n * a.floorDiv(b).print(); // or tf.div(a, b)\n * ```\n *\n * ```js\n * // Broadcast div a with b.\n * const a = tf.tensor1d([2, 4, 6, 8]);\n * const b = tf.scalar(2);\n *\n * a.floorDiv(b).print(); // or tf.floorDiv(a, b)\n * ```\n *\n * @param a The first tensor as the numerator.\n * @param b The second tensor as the denominator. Must have the same dtype as\n * `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */\nfunction floorDiv_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'floorDiv');\n let $b = convertToTensor(b, 'b', 'floorDiv');\n [$a, $b] = makeTypesMatch($a, $b);\n\n const outShape =\n broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);\n const der = (dy: Tensor, saved: Tensor[]) => {\n const [$a, $b] = saved;\n const derA = () => {\n const res = dy.div($b.toFloat());\n const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);\n if (reduceAxes.length > 0) {\n return res.sum(reduceAxes).reshape($a.shape);\n }\n return res;\n };\n const derB = () => {\n let res = dy.mul($a.toFloat());\n const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);\n if (reduceAxes.length > 0) {\n res = res.sum(reduceAxes).reshape($b.shape);\n }\n const tmp = $b.square();\n return res.div(tmp.toFloat()).neg();\n };\n return {a: derA, b: derB};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.floorDiv($a, $b);\n save([$a, $b]);\n return res;\n }, {a: $a, b: $b}, der, 'FloorDiv') as T;\n}\n\n/**\n * Divides two `tf.Tensor`s element-wise, A / B. Inputs must\n * be the same shape.\n *\n * @param a The first tensor as the numerator for element-wise division.\n * @param b The second tensor as the denominator for element-wise division.\n */\nfunction divStrict_(a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'div');\n const $b = convertToTensor(b, 'b', 'div');\n util.assertShapesMatch($a.shape, $b.shape, 'Error in divideStrict: ');\n return $a.div($b);\n}\n\n/**\n * Returns the mod of a and b element-wise.\n * `floor(x / y) * y + mod(x, y) = x`\n * Supports broadcasting.\n *\n * We also expose `tf.modStrict` which has the same signature as this op and\n * asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 4, 3, 16]);\n * const b = tf.tensor1d([1, 2, 9, 4]);\n *\n * a.mod(b).print(); // or tf.mod(a, b)\n * ```\n *\n * ```js\n * // Broadcast a mod b.\n * const a = tf.tensor1d([2, 4, 6, 8]);\n * const b = tf.scalar(5);\n *\n * a.mod(b).print(); // or tf.mod(a, b)\n * ```\n *\n * @param a The first tensor.\n * @param b The second tensor. Must have the same type as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */\nfunction mod_(a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'mod');\n let $b = convertToTensor(b, 'b', 'mod');\n [$a, $b] = makeTypesMatch($a, $b);\n\n const outShape =\n broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);\n const der = (dy: Tensor, saved: Tensor[]) => {\n const [$a, $b] = saved;\n const derA = () => {\n const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);\n if (reduceAxes.length > 0) {\n return dy.sum(reduceAxes).reshape($a.shape);\n }\n return dy;\n };\n const derB = () => {\n const res = dy.mul($a.div($b).floor().neg());\n const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);\n if (reduceAxes.length > 0) {\n return res.sum(reduceAxes).reshape($b.shape);\n }\n return res;\n };\n return {$a: derA, $b: derB};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.mod($a, $b);\n save([$a, $b]);\n return res;\n }, {$a, $b}, der) as T;\n}\n\n/**\n * Returns the mod of a and b (`a < b ? a : b`) element-wise. Inputs must\n * be the same shape. For broadcasting support, use mod().\n *\n * @param a The first tensor.\n * @param b The second tensor. Must have the same dtype as `a`.\n */\nfunction modStrict_(a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'modStrict');\n const $b = convertToTensor(b, 'b', 'modStrict');\n util.assertShapesMatch($a.shape, $b.shape, 'Error in modStrict: ');\n return $a.mod($b);\n}\n\n/**\n * Returns the min of a and b (`a < b ? a : b`) element-wise.\n * Supports broadcasting.\n *\n * We also expose `minimumStrict` which has the same signature as this op and\n * asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 4, 3, 16]);\n * const b = tf.tensor1d([1, 2, 9, 4]);\n *\n * a.minimum(b).print(); // or tf.minimum(a, b)\n * ```\n *\n * ```js\n * // Broadcast minimum a with b.\n * const a = tf.tensor1d([2, 4, 6, 8]);\n * const b = tf.scalar(5);\n *\n * a.minimum(b).print(); // or tf.minimum(a, b)\n * ```\n *\n * @param a The first tensor.\n * @param b The second tensor. Must have the same type as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */\nfunction minimum_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'minimum');\n let $b = convertToTensor(b, 'b', 'minimum');\n [$a, $b] = makeTypesMatch($a, $b);\n\n if ($a.dtype === 'bool') {\n $a = $a.toInt();\n $b = $b.toInt();\n }\n\n broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);\n const der = (dy: Tensor, saved: Tensor[]) => {\n const [$a, $b] = saved;\n const derA = () => dy.mul($a.lessEqual($b).toFloat());\n const derB = () => dy.mul($a.greater($b).toFloat());\n return {a: derA, b: derB};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.minimum($a, $b);\n save([$a, $b]);\n return res;\n }, {a: $a, b: $b}, der, 'Minimum') as T;\n}\n\n/**\n * Returns the min of a and b (`a < b ? a : b`) element-wise. Inputs must\n * be the same shape. For broadcasting support, use minimum().\n *\n * @param a The first tensor.\n * @param b The second tensor. Must have the same dtype as `a`.\n */\nfunction minimumStrict_(a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'minimumStrict');\n const $b = convertToTensor(b, 'b', 'minimumStrict');\n util.assertShapesMatch($a.shape, $b.shape, 'Error in minimumStrict: ');\n return $a.minimum($b);\n}\n\n/**\n * Returns the max of a and b (`a > b ? a : b`) element-wise.\n * Supports broadcasting.\n *\n * We also expose `tf.maximumStrict` which has the same signature as this op and\n * asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 4, 3, 16]);\n * const b = tf.tensor1d([1, 2, 9, 4]);\n *\n * a.maximum(b).print(); // or tf.maximum(a, b)\n * ```\n *\n * ```js\n * // Broadcast maximum a with b.\n * const a = tf.tensor1d([2, 4, 6, 8]);\n * const b = tf.scalar(5);\n *\n * a.maximum(b).print(); // or tf.maximum(a, b)\n * ```\n *\n * @param a The first tensor.\n * @param b The second tensor. Must have the same type as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */\nfunction maximum_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'maximum');\n let $b = convertToTensor(b, 'b', 'maximum');\n [$a, $b] = makeTypesMatch($a, $b);\n\n if ($a.dtype === 'bool') {\n $a = $a.toInt();\n $b = $b.toInt();\n }\n\n broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);\n const der = (dy: Tensor, saved: Tensor[]) => {\n const [$a, $b] = saved;\n const derA = () => dy.mul($a.greaterEqual($b).toFloat());\n const derB = () => dy.mul($a.less($b).toFloat());\n return {a: derA, b: derB};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.maximum($a, $b);\n save([$a, $b]);\n return res;\n }, {a: $a, b: $b}, der, 'Maximum') as T;\n}\n\n/**\n * Returns the max of a and b (`a > b ? a : b`) element-wise. Inputs must\n * be the same shape. For broadcasting support, use maximum().\n *\n * @param a The first tensor.\n * @param b The second tensor. Must have the same dtype as `a`.\n */\nfunction maximumStrict_(a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'maximumStrict');\n const $b = convertToTensor(b, 'b', 'maximumStrict');\n util.assertShapesMatch($a.shape, $b.shape, 'Error in maximumStrict: ');\n return $a.maximum($b);\n}\n\n/**\n * Returns (a - b) * (a - b) element-wise.\n *\n * Inputs must be the same shape. For broadcasting support, use\n * `tf.squaredDifference` instead.\n *\n * @param a The first tensor.\n * @param b The second tensor. Must have the same type as `a`.\n */\nfunction squaredDifferenceStrict_(\n a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'squaredDifferenceStrict');\n const $b = convertToTensor(b, 'b', 'squaredDifferenceStrict');\n util.assertShapesMatch(\n $a.shape, $b.shape, 'Error in squaredDifferenceStrict: ');\n return $a.squaredDifference($b);\n}\n\n/**\n * Computes arctangent of `tf.Tensor`s a / b element-wise: `atan2(a, b)`.\n * Supports broadcasting.\n *\n * ```js\n * const a = tf.tensor1d([1.0, 1.0, -1.0, .7]);\n * const b = tf.tensor1d([2.0, 13.0, 3.5, .21]);\n *\n * tf.atan2(a, b).print()\n * ```\n *\n * @param a The first tensor.\n * @param b The second tensor. Must have the same dtype as `a`.\n *\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction atan2_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'atan2');\n let $b = convertToTensor(b, 'b', 'atan2');\n [$a, $b] = makeTypesMatch($a, $b);\n\n const outShape =\n broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);\n\n const der = (dy: Tensor, saved: Tensor[]) => {\n const [$a, $b] = saved;\n const derA = () => {\n const d = add($a.square(), $b.square());\n let res = dy.mul($b.div(d));\n const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);\n if (reduceAxes.length > 0) {\n res = res.sum(reduceAxes);\n }\n return res.reshape($a.shape);\n };\n const derB = () => {\n const d = add($a.square(), $b.square());\n let res = neg(dy.mul($a.div(d)));\n const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);\n if (reduceAxes.length > 0) {\n res = res.sum(reduceAxes);\n }\n return res.reshape($b.shape);\n };\n return {$a: derA, $b: derB};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.atan2($a, $b);\n save([$a, $b]);\n return res;\n }, {$a, $b}, der) as T;\n}\n\nexport const addStrict = op({addStrict_});\nexport const atan2 = op({atan2_});\nexport const divStrict = op({divStrict_});\nexport const floorDiv = op({floorDiv_});\nexport const maximum = op({maximum_});\nexport const maximumStrict = op({maximumStrict_});\nexport const minimum = op({minimum_});\nexport const minimumStrict = op({minimumStrict_});\nexport const mod = op({mod_});\nexport const modStrict = op({modStrict_});\nexport const mul = op({mul_});\nexport const mulStrict = op({mulStrict_});\nexport const pow = op({pow_});\nexport const powStrict = op({powStrict_});\nexport const squaredDifferenceStrict = op({squaredDifferenceStrict_});\nexport const sub = op({sub_});\nexport const subStrict = op({subStrict_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE, ForwardFunc} from '../engine';\nimport {Div, DivInputs} from '../kernel_names';\nimport {Tensor} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {makeTypesMatch} from '../tensor_util';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\n\nimport {floorDiv} from './binary_ops';\nimport {op} from './operation';\n\n/**\n * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.\n *\n * We also expose `tf.divStrict` which has the same signature as this op and\n * asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 4, 9, 16]);\n * const b = tf.tensor1d([1, 2, 3, 4]);\n *\n * a.div(b).print(); // or tf.div(a, b)\n * ```\n *\n * ```js\n * // Broadcast div a with b.\n * const a = tf.tensor1d([2, 4, 6, 8]);\n * const b = tf.scalar(2);\n *\n * a.div(b).print(); // or tf.div(a, b)\n * ```\n *\n * @param a The first tensor as the numerator.\n * @param b The second tensor as the denominator. Must have the same dtype as\n * `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */\nfunction div_(a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'div');\n let $b = convertToTensor(b, 'b', 'div');\n [$a, $b] = makeTypesMatch($a, $b);\n\n if ($a.dtype === 'int32' && $b.dtype === 'int32') {\n return floorDiv($a, $b);\n }\n\n const forward: ForwardFunc = (backend, save) => {\n const res = backend.realDivide($a, $b);\n save([$a, $b]);\n return res;\n };\n\n const inputs: DivInputs = {a: $a, b: $b};\n const attrs = {};\n\n return ENGINE.runKernelFunc(\n forward, inputs as {} as NamedTensorMap, null /* gradient */, Div,\n attrs) as T;\n}\n\nexport const div = op({div_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {Tensor} from '../tensor';\nimport {computeStrides} from '../util';\n\n/**\n * Validate gather nd inputs.\n *\n * @param tensor The tensor contains the source values.\n * @param indices The tensor contains the indices to slice the source.\n *\n * @returns [resultShape, numUpdates, sliceSize, strides]\n */\nexport function prepareAndValidate(\n tensor: Tensor, indices: Tensor): [number[], number, number, number[]] {\n if (tensor.rank < 1) {\n throw new Error(\n 'tf.gatherND() expects the input to be rank 1 or higher,' +\n ` but the rank was ${tensor.rank}.`);\n }\n if (indices.rank < 1) {\n throw new Error(\n 'tf.gatherND() expects the indices to be rank 1 or higher,' +\n ` but the rank was ${indices.rank}.`);\n }\n if (indices.dtype !== 'int32') {\n throw new Error(\n 'tf.gatherND() expects the indices to be int32 type,' +\n ` but the dtype was ${indices.dtype}.`);\n }\n if (indices.shape[indices.rank - 1] > tensor.rank) {\n throw new Error(\n 'index innermost dimension length must be <= tensor rank; saw: ' +\n `${indices.shape[indices.rank - 1]} vs. ${tensor.rank}`);\n }\n\n if (tensor.size === 0) {\n throw new Error(\n 'Requested more than 0 entries, but input is empty.' +\n ` Input shape: ${tensor.shape}.`);\n }\n\n const indicesShape = indices.shape;\n const sliceRank = indicesShape[indicesShape.length - 1];\n\n // The result shape is\n // indices.shape[:-1] + params.shape[indices.shape[-1]:]\n let nResult = 1;\n for (let i = 0; i < indicesShape.length - 1; ++i) {\n nResult *= indicesShape[i];\n }\n\n const inputShape = tensor.shape;\n\n const resultShape = indicesShape.slice();\n resultShape.pop();\n\n let sliceSize = 1;\n for (let i = sliceRank; i < tensor.rank; ++i) {\n sliceSize *= inputShape[i];\n resultShape.push(inputShape[i]);\n }\n\n const strides =\n [...computeStrides(tensor.shape).map(stride => stride / sliceSize),\n 1].slice(0, sliceRank);\n\n return [resultShape, nResult, sliceSize, strides];\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * Inputs of size above this threshold will be parallelized by calling multiple\n * shader programs.\n */\nimport {nearestDivisor} from '../util';\n\nexport const PARALLELIZE_THRESHOLD = 30;\n\nexport interface ReduceInfo {\n windowSize: number;\n batchSize: number;\n inSize: number;\n}\n\nexport function computeOptimalWindowSize(inSize: number): number {\n if (inSize <= PARALLELIZE_THRESHOLD) {\n return inSize;\n }\n return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {TensorInfo} from '../kernel_registry';\nimport {Tensor} from '../tensor';\nimport {computeStrides, sizeFromShape} from '../util';\n\n/**\n * Check whether updates.shape = indices.shape[:batchDim] +\n * shape[sliceDim:]\n *\n * @param x The input tensor.\n */\nexport function validateUpdateShape(\n shape: number[], indices: Tensor, updates: Tensor) {\n const sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;\n const batchDim = (indices.rank > 1) ? indices.rank - 1 : 1;\n\n const shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' +\n `shape[sliceDim:], got updates.shape: ${updates.shape}` +\n `, indices.shape: ${indices.shape}, shape: ${shape}` +\n `, sliceDim: ${sliceDim}, and batchDim: ${batchDim}.`;\n\n if (updates.rank < batchDim) {\n throw new Error(shapeError + ` update.rank < ${batchDim}. `);\n }\n if (shape.length < sliceDim + (updates.rank - batchDim)) {\n throw new Error(\n shapeError +\n ` Output shape length < ${sliceDim + (updates.rank - batchDim)}`);\n }\n if (updates.rank !== batchDim + shape.length - sliceDim) {\n throw new Error(\n shapeError + ` update.rank != ${batchDim + shape.length - sliceDim}`);\n }\n for (let d = 0; d < batchDim; ++d) {\n if (updates.shape[d] !== indices.shape[d]) {\n throw new Error(\n shapeError +\n ` updates.shape[${d}] (${updates.shape[d]}) != indices.shape[${d}] (${\n indices.shape[d]}).`);\n }\n }\n for (let d = 0; d < updates.rank - batchDim; ++d) {\n if (updates.shape[d + batchDim] !== shape[d + sliceDim]) {\n throw new Error(\n shapeError +\n ` updates.shape[${d + batchDim}] (${\n updates.shape[d + batchDim]}) != shape[${d + batchDim}] (${\n shape[d + batchDim]})`);\n }\n }\n}\n\nexport interface ScatterShapeInfo {\n sliceRank: number;\n numUpdates: number;\n sliceSize: number;\n strides: number[];\n outputSize: number;\n}\n/**\n * Validate scatter nd inputs.\n *\n * @param update The tensor contains the update values.\n * @param indices The tensor contains the indices for the update values.\n * @param shape The shape of the output tensor.\n */\nexport function validateInput(\n updates: Tensor, indices: Tensor, shape: number[]) {\n if (indices.rank < 1) {\n throw new Error(\n 'tf.scatterND() expects the indices to be rank 1 or higher,' +\n ` but the rank was ${indices.rank}.`);\n }\n if (updates.rank < 1) {\n throw new Error(\n 'tf.scatterND() expects the updates to be rank 1 or higher,' +\n ` but the rank was ${updates.rank}.`);\n }\n if (indices.dtype !== 'int32') {\n throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${\n indices.dtype}`);\n }\n if (shape.length < 1) {\n throw new Error(\n `Output rank must be greater or equal to 1, but got shape: ${shape}`);\n }\n\n if (shape.length === 0) {\n if (indices.size === 0) {\n throw new Error(`Indices specified for empty output. indices shape: ${\n indices.shape}`);\n }\n if (updates.size === 0) {\n throw new Error(`Updates specified for empty output. updates shape: ${\n updates.shape}`);\n }\n }\n\n validateUpdateShape(shape, indices, updates);\n}\n\n/**\n * Calculate the shape information for the output.\n *\n * @param update The tensor contains the update values.\n * @param indices The tensor contains the indices for the update values.\n * @param shape The shape of the output tensor.\n *\n * @returns ScatterShapeInfo\n */\nexport function calculateShapes(\n updates: TensorInfo, indices: TensorInfo,\n shape: number[]): ScatterShapeInfo {\n // Calculate the number of dimensions in indices\n const indicesRank = indices.shape.length;\n const sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1;\n\n // Calculate the number of elements that make up each slice of our updated\n // tensor. This allows us to work with flattened tensors and copy over whole\n // slices at a time.\n const totalNd = shape.length;\n\n let sliceSize = 1;\n for (let i = sliceRank; i < totalNd; ++i) {\n sliceSize *= shape[i];\n }\n\n const safeSliceDim = (sliceRank < 1) ? 1 : sliceRank;\n const numUpdates = sizeFromShape(indices.shape) / safeSliceDim;\n\n const strides = [...computeStrides(shape.slice(0, sliceRank)), 1];\n const outputSize = sizeFromShape(shape);\n return {sliceRank, numUpdates, sliceSize, strides, outputSize};\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\nimport * as util from '../util';\n\nexport function assertParamsValid(\n input: Tensor, begin: number[], size: number[]): void {\n util.assert(\n input.rank === begin.length,\n () => `Error in slice${input.rank}D: Length of begin ${begin} must ` +\n `match the rank of the array (${input.rank}).`);\n util.assert(\n input.rank === size.length,\n () => `Error in slice${input.rank}D: Length of size ${size} must ` +\n `match the rank of the array (${input.rank}).`);\n\n for (let i = 0; i < input.rank; ++i) {\n util.assert(\n begin[i] + size[i] <= input.shape[i],\n () => `Error in slice${input.rank}D: begin[${i}] + size[${i}] ` +\n `(${begin[i] + size[i]}) would overflow input.shape[${i}] (${\n input.shape[i]})`);\n }\n}\n\n/** Converts a binary mask to an array of axes. Used in stridedSlice(). */\nexport function maskToAxes(mask: number): number[] {\n const axes = [];\n let axis = 0;\n while (mask > 0) {\n if (mask & 1) {\n axes.push(axis);\n }\n mask /= 2;\n axis++;\n }\n return axes;\n}\n\n/** Computes the output shape given the strided slice params. */\nexport function computeOutShape(\n begin: number[], end: number[], strides: number[]): number[] {\n const size = [];\n for (let axis = 0; axis < begin.length; axis++) {\n size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]);\n }\n return size;\n}\n\nexport function startForAxis(\n beginMask: number, startIndices: number[], strides: number[],\n inputShape: number[], axis: number): number {\n // Begin with the specified index\n let start = startIndices[axis];\n const stride = strides[axis] || 1;\n\n // Check the axis bit from right of beginMask or the begin index is not set\n // for the axis.\n if (beginMask & 1 << axis || start == null) {\n if (stride > 0) {\n // Forward iteration - use the first element. These values will get\n // clamped below (Note: We could have set them to 0 and axis_size-1, but\n // use lowest() and max() to maintain symmetry with StopForAxis())\n start = Number.MIN_SAFE_INTEGER;\n } else {\n // Backward iteration - use the last element.\n start = Number.MAX_SAFE_INTEGER;\n }\n }\n\n // Handle negative indices\n const axisSize = inputShape[axis];\n if (start < 0) {\n start += axisSize;\n }\n\n // Clamping\n start = util.clamp(0, start, axisSize - 1);\n\n return start;\n}\n\nexport function stopForAxis(\n endMask: number, stopIndices: number[], strides: number[],\n inputShape: number[], axis: number): number {\n // Begin with the specified index\n let stop = stopIndices[axis];\n const stride = strides[axis] || 1;\n\n // Check the axis bit from right of endMask or if the stop index is not set\n // for this axis.\n if (endMask & (1 << axis) || stop == null) {\n if (stride > 0) {\n // Forward iteration - use the last element. These values will get\n // clamped below\n stop = Number.MAX_SAFE_INTEGER;\n } else {\n // Backward iteration - use the first element.\n stop = Number.MIN_SAFE_INTEGER;\n }\n }\n\n // Handle negative indices\n const axisSize = inputShape[axis];\n if (stop < 0) {\n stop += axisSize;\n }\n\n // Clamping\n // Because the end index points one past the last element, we need slightly\n // different clamping ranges depending on the direction.\n if (stride > 0) {\n // Forward iteration\n stop = util.clamp(0, stop, axisSize);\n } else {\n // Backward iteration\n stop = util.clamp(-1, stop, axisSize - 1);\n }\n\n return stop;\n}\n\n/**\n * Returns true if the slice occupies a continous set of elements in the\n * 'flat' space.\n */\nexport function isSliceContinous(\n shape: number[], begin: number[], size: number[]) {\n // Index of the first axis that has size > 1.\n let firstNonOneAxis = size.length;\n for (let i = 0; i < size.length; i++) {\n if (size[i] > 1) {\n firstNonOneAxis = i;\n break;\n }\n }\n\n for (let i = firstNonOneAxis + 1; i < size.length; i++) {\n if (begin[i] > 0 || size[i] !== shape[i]) {\n return false;\n }\n }\n return true;\n}\n\nexport function computeFlatOffset(begin: number[], strides: number[]): number {\n let flatOffset = begin.length > 0 ? begin[begin.length - 1] : 1;\n for (let i = 0; i < begin.length - 1; i++) {\n flatOffset += begin[i] * strides[i];\n }\n return flatOffset;\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {CustomGradientFunc, ENGINE} from './engine';\nimport {Scalar, Tensor, Variable} from './tensor';\nimport {NamedTensorMap} from './tensor_types';\nimport {convertToTensor, convertToTensorArray} from './tensor_util_env';\nimport {TensorLike} from './types';\nimport * as util from './util';\n\n/**\n * Provided `f(x)`, returns another function `g(x, dy?)`, which gives the\n * gradient of `f(x)` with respect to `x`.\n *\n * If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to\n * `x` is computed instead. `f(x)` must take a single tensor `x` and return a\n * single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead.\n *\n * ```js\n * // f(x) = x ^ 2\n * const f = x => x.square();\n * // f'(x) = 2x\n * const g = tf.grad(f);\n *\n * const x = tf.tensor1d([2, 3]);\n * g(x).print();\n * ```\n *\n * ```js\n * // f(x) = x ^ 3\n * const f = x => x.pow(tf.scalar(3, 'int32'));\n * // f'(x) = 3x ^ 2\n * const g = tf.grad(f);\n * // f''(x) = 6x\n * const gg = tf.grad(g);\n *\n * const x = tf.tensor1d([2, 3]);\n * gg(x).print();\n * ```\n *\n * @param f The function f(x), to compute gradient for.\n */\n/** @doc {heading: 'Training', subheading: 'Gradients'} */\nfunction grad(f: (x: Tensor) => Tensor): (\n x: TensorLike|Tensor, dy?: TensorLike|Tensor) => Tensor {\n util.assert(\n util.isFunction(f), () => 'The f passed in grad(f) must be a function');\n return (x: TensorLike|Tensor, dy?: TensorLike|Tensor): Tensor => {\n // x can be of any dtype, thus null as the last argument.\n const $x = convertToTensor(x, 'x', 'tf.grad', null);\n const $dy: Tensor =\n (dy != null) ? convertToTensor(dy, 'dy', 'tf.grad') : null;\n return ENGINE.tidy(() => {\n const {value, grads} = ENGINE.gradients(() => f($x), [$x], $dy);\n if ($dy != null) {\n util.assertShapesMatch(\n value.shape, $dy.shape,\n 'The shape of dy passed in grad(f)(x, dy) must match the shape ' +\n 'returned by f(x)');\n }\n checkGrads(grads);\n return grads[0];\n });\n };\n}\n\n/**\n * Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`,\n * which gives an array of gradients of `f()` with respect to each input\n * [`x1`,`x2`,...].\n *\n * If `dy` is passed when calling `g()`, the gradient of\n * `f(x1,...).mul(dy).sum()` with respect to each input is computed instead.\n * The provided `f` must take one or more tensors and return a single tensor\n * `y`. If `f()` takes a single input, we recommend using `tf.grad` instead.\n *\n * ```js\n * // f(a, b) = a * b\n * const f = (a, b) => a.mul(b);\n * // df / da = b, df / db = a\n * const g = tf.grads(f);\n *\n * const a = tf.tensor1d([2, 3]);\n * const b = tf.tensor1d([-2, -3]);\n * const [da, db] = g([a, b]);\n * console.log('da');\n * da.print();\n * console.log('db');\n * db.print();\n * ```\n *\n * @param f The function `f(x1, x2,...)` to compute gradients for.\n */\n/** @doc {heading: 'Training', subheading: 'Gradients'} */\nfunction grads(f: (...args: Tensor[]) => Tensor): (\n args: Array, dy?: Tensor|TensorLike) => Tensor[] {\n util.assert(\n util.isFunction(f), () => 'The f passed in grads(f) must be a function');\n return (args: Array, dy?: Tensor|TensorLike): Tensor[] => {\n util.assert(\n Array.isArray(args),\n () => 'The args passed in grads(f)(args) must be an array ' +\n 'of `Tensor`s or `TensorLike`s');\n // args can be of any dtype, thus null as the last argument.\n const $args = convertToTensorArray(args, 'args', 'tf.grads', null);\n const $dy: Tensor =\n (dy != null) ? convertToTensor(dy, 'dy', 'tf.grads') : null;\n return ENGINE.tidy(() => {\n const {value, grads} = ENGINE.gradients(() => f(...$args), $args, $dy);\n if ($dy != null) {\n util.assertShapesMatch(\n value.shape, $dy.shape,\n 'The shape of dy passed in grads(f)([x1,...], dy) must ' +\n 'match the shape returned by f([x1,...])');\n }\n checkGrads(grads);\n return grads;\n });\n };\n}\n\n/**\n * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()`\n * returns a metric you want to show.\n *\n * The result is a rich object with the following properties:\n * - grad: The gradient of `f(x)` w.r.t `x` (result of `tf.grad`).\n * - value: The value returned by `f(x)`.\n *\n * ```js\n * // f(x) = x ^ 2\n * const f = x => x.square();\n * // f'(x) = 2x\n * const g = tf.valueAndGrad(f);\n *\n * const x = tf.tensor1d([2, 3]);\n * const {value, grad} = g(x);\n *\n * console.log('value');\n * value.print();\n * console.log('grad');\n * grad.print();\n * ```\n */\n/** @doc {heading: 'Training', subheading: 'Gradients'} */\nfunction valueAndGrad(f: (x: I) => O): (\n x: I, dy?: O) => {\n value: O;\n grad: I;\n} {\n util.assert(\n util.isFunction(f),\n () => 'The f passed in valueAndGrad(f) must be a function');\n return (x: I, dy?: O) => {\n util.assert(\n x instanceof Tensor,\n () => 'The x passed in valueAndGrad(f)(x) must be a tensor');\n util.assert(\n dy == null || dy instanceof Tensor,\n () => 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor');\n const {grads, value} = ENGINE.gradients(() => f(x), [x], dy);\n checkGrads(grads);\n return {grad: grads[0] as I, value};\n };\n}\n\n/**\n * Like `tf.grads`, but returns also the value of `f()`. Useful when `f()`\n * returns a metric you want to show.\n *\n * The result is a rich object with the following properties:\n * - grads: The gradients of `f()` w.r.t each input (result of `tf.grads`).\n * - value: The value returned by `f(x)`.\n *\n * ```js\n * // f(a, b) = a * b\n * const f = (a, b) => a.mul(b);\n * // df/da = b, df/db = a\n * const g = tf.valueAndGrads(f);\n *\n * const a = tf.tensor1d([2, 3]);\n * const b = tf.tensor1d([-2, -3]);\n * const {value, grads} = g([a, b]);\n *\n * const [da, db] = grads;\n *\n * console.log('value');\n * value.print();\n *\n * console.log('da');\n * da.print();\n * console.log('db');\n * db.print();\n * ```\n */\n/** @doc {heading: 'Training', subheading: 'Gradients'} */\nfunction valueAndGrads(f: (...args: Tensor[]) => O): (\n args: Tensor[], dy?: O) => {\n grads: Tensor[];\n value: O;\n} {\n util.assert(\n util.isFunction(f),\n () => 'The f passed in valueAndGrads(f) must be a function');\n return (args: Tensor[], dy?: O) => {\n util.assert(\n Array.isArray(args) && args.every(arg => arg instanceof Tensor),\n () => 'The args passed in valueAndGrads(f)(args) must be array of ' +\n 'tensors');\n util.assert(\n dy == null || dy instanceof Tensor,\n () => 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor');\n const res = ENGINE.gradients(() => f(...args), args, dy);\n if (dy != null) {\n util.assertShapesMatch(\n res.value.shape, dy.shape,\n 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' +\n 'match the shape returned by f([x1,...])');\n }\n checkGrads(res.grads);\n return res;\n };\n}\n\n/**\n * Computes and returns the gradient of f(x) with respect to the list of\n * trainable variables provided by `varList`. If no list is provided, it\n * defaults to all trainable variables.\n *\n * ```js\n * const a = tf.variable(tf.tensor1d([3, 4]));\n * const b = tf.variable(tf.tensor1d([5, 6]));\n * const x = tf.tensor1d([1, 2]);\n *\n * // f(a, b) = a * x ^ 2 + b * x\n * const f = () => a.mul(x.square()).add(b.mul(x)).sum();\n * // df/da = x ^ 2, df/db = x\n * const {value, grads} = tf.variableGrads(f);\n *\n * Object.keys(grads).forEach(varName => grads[varName].print());\n * ```\n *\n * @param f The function to execute. f() should return a scalar.\n * @param varList The list of variables to compute the gradients with respect\n * to. Defaults to all trainable variables.\n * @returns An object with the following keys and values:\n * - `value`: The value of the function `f`.\n * - `grads`: A map from the names of the variables to the gradients.\n * If the `varList` argument is provided explicitly and contains a subset of\n * non-trainable variables, this map in the return value will contain keys\n * that map the names of the non-trainable variables to `null`.\n */\n/** @doc {heading: 'Training', subheading: 'Gradients'} */\nfunction variableGrads(f: () => Scalar, varList?: Variable[]):\n {value: Scalar, grads: NamedTensorMap} {\n util.assert(\n util.isFunction(f),\n () => 'The f passed in variableGrads(f) must be a function');\n util.assert(\n varList == null ||\n Array.isArray(varList) && varList.every(v => v instanceof Variable),\n () =>\n 'The varList passed in variableGrads(f, varList) must be an array ' +\n 'of variables');\n\n const specifiedVarList = varList != null;\n if (!specifiedVarList) {\n // Get all of the trainable variables.\n varList = [];\n for (const varName in ENGINE.registeredVariables) {\n varList.push(ENGINE.registeredVariables[varName]);\n }\n }\n\n const specifiedNonTrainable: Variable[] =\n specifiedVarList ? varList.filter(variable => !variable.trainable) : null;\n\n // Prune non-trainable variables.\n const originalVarCount = varList.length;\n varList = varList.filter(variable => variable.trainable);\n util.assert(\n varList.length > 0,\n () => `variableGrads() expects at least one of the input variables to ` +\n `be trainable, but none of the ${originalVarCount} variables is ` +\n `trainable.`);\n\n const allowNoGradients = true;\n const {value, grads} = ENGINE.gradients(f, varList, null, allowNoGradients);\n\n util.assert(\n grads.some(g => g != null),\n () => 'Cannot find a connection between any variable and the result of ' +\n 'the loss function y=f(x). Please make sure the operations that ' +\n 'use variables are inside the function f passed to minimize().');\n util.assert(\n value.rank === 0,\n () => `The f passed in variableGrads(f) must return a scalar, but it ` +\n `returned a rank-${value.rank} tensor`);\n\n const namedGrads: NamedTensorMap = {};\n varList.forEach((v, i) => {\n if (grads[i] != null) {\n namedGrads[v.name] = grads[i];\n }\n });\n if (specifiedNonTrainable != null) {\n // If varList is explicitly provided and contains non-trainable values,\n // add them to the returned gradients with `null` values.\n specifiedNonTrainable.forEach(v => namedGrads[v.name] = null);\n }\n return {value, grads: namedGrads};\n}\n\n/**\n * Overrides the gradient computation of a function `f`.\n *\n * Takes a function\n * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}`\n * and returns another function `g(...inputs)` which takes the same inputs as\n * `f`. When called, `g` returns `f().value`. In backward mode, custom gradients\n * with respect to each input of `f` are computed using `f().gradFunc`.\n *\n * The `save` function passsed to `f` should be used for saving tensors needed\n * in the gradient. And the `saved` passed to the `gradFunc` is a\n * `NamedTensorMap`, which contains those saved tensor.\n *\n * ```js\n * const customOp = tf.customGrad((x, save) => {\n * // Save x to make sure it's available later for the gradient.\n * save([x]);\n * // Override gradient of our custom x ^ 2 op to be dy * abs(x);\n * return {\n * value: x.square(),\n * // Note `saved.x` which points to the `x` we saved earlier.\n * gradFunc: (dy, saved) => [dy.mul(saved[0].abs())]\n * };\n * });\n *\n * const x = tf.tensor1d([-1, -2, 3]);\n * const dx = tf.grad(x => customOp(x));\n *\n * console.log(`f(x):`);\n * customOp(x).print();\n * console.log(`f'(x):`);\n * dx(x).print();\n * ```\n *\n * @param f The function to evaluate in forward mode, which should return\n * `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc`\n * returns the custom gradients of `f` with respect to its inputs.\n */\n/** @doc {heading: 'Training', subheading: 'Gradients'} */\nfunction customGrad(f: CustomGradientFunc):\n (...args: Tensor[]) => T {\n return ENGINE.customGrad(f);\n}\n\nfunction checkGrads(grads: Tensor[]) {\n const numNullGradients = grads.filter(g => g == null).length;\n if (numNullGradients > 0) {\n throw new Error(\n `Cannot compute gradient of y=f(x) with respect to x. Make sure that\n the f you passed encloses all operations that lead from x to y.`);\n }\n}\n\nexport {\n customGrad,\n variableGrads,\n valueAndGrad,\n valueAndGrads,\n grad,\n grads,\n};\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {customGrad} from '../gradients';\nimport {Tensor} from '../tensor';\nimport {GradSaveFunc} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\n\nimport {op} from './operation';\n\n/**\n * Computes the softmax normalized vector given the logits.\n *\n * ```js\n * const a = tf.tensor1d([1, 2, 3]);\n *\n * a.softmax().print(); // or tf.softmax(a)\n * ```\n *\n * ```js\n * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);\n *\n * a.softmax().print(); // or tf.softmax(a)\n * ```\n *\n * @param logits The logits array.\n * @param dim The dimension softmax would be performed on. Defaults to `-1`\n * which indicates the last dimension.\n */\n/** @doc {heading: 'Operations', subheading: 'Normalization'} */\nfunction softmax_(logits: T|TensorLike, dim = -1): T {\n const $logits = convertToTensor(logits, 'logits', 'softmax', 'float32');\n\n if (dim === -1) {\n dim = $logits.rank - 1;\n }\n if (dim !== $logits.rank - 1) {\n throw Error(\n 'Softmax along a non-last dimension is not yet supported. ' +\n `Logits was rank ${$logits.rank} and dim was ${dim}`);\n }\n\n const inputsToSave: Tensor[] = [];\n const outputsToSave = [true];\n\n return ENGINE.runKernelFunc(\n (backend, save) => {\n const y = backend.softmax($logits, dim);\n save([y]);\n return y;\n },\n {logits: $logits},\n (dy: T, saved: Tensor[]) => {\n const [y] = saved;\n const dyTimesY = dy.mul(y);\n const keepDims = true;\n\n return {\n logits: () => dyTimesY.sub(dyTimesY.sum([dim], keepDims).mul(y))\n };\n },\n 'Softmax', {dim}, inputsToSave, outputsToSave);\n}\n\n/**\n * Computes the log softmax.\n *\n * ```js\n * const a = tf.tensor1d([1, 2, 3]);\n *\n * a.logSoftmax().print(); // or tf.logSoftmax(a)\n * ```\n *\n * ```js\n * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);\n *\n * a.logSoftmax().print(); // or tf.logSoftmax(a)\n * ```\n *\n * @param logits The logits array.\n * @param axis The dimension softmax would be performed on. Defaults to `-1`\n * which indicates the last dimension.\n */\n/** @doc {heading: 'Operations', subheading: 'Normalization'} */\nfunction logSoftmax_(logits: T|TensorLike, axis = -1): T {\n const $logits = convertToTensor(logits, 'logits', 'logSoftmax');\n\n if (axis === -1) {\n axis = $logits.rank - 1;\n }\n if (axis !== $logits.rank - 1) {\n throw Error(\n 'Log Softmax along a non-last dimension is not yet supported. ' +\n `Logits was rank ${$logits.rank} and axis was ${axis}`);\n }\n\n const customOp = customGrad((logits: Tensor, save: GradSaveFunc) => {\n const keepDims = true;\n const xMax = logits.max(axis, true);\n const shifted = logits.sub(xMax);\n const value =\n shifted.toFloat().sub(shifted.exp().sum(axis, keepDims).log());\n save([value]);\n const gradFunc = (dy: T, saved: Tensor[]) => {\n const [value] = saved;\n const softmax = value.exp();\n return dy.sub(dy.sum(axis, keepDims).mul(softmax));\n };\n\n return {value, gradFunc};\n });\n\n return customOp($logits) as T;\n}\n\nexport const softmax = op({softmax_});\nexport const logSoftmax = op({logSoftmax_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\nimport {op} from './operation';\n\n/**\n * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`.\n *\n * The returned `tf.Tensor`'s dimension `i` will correspond to the input\n * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`,\n * where `n` is the rank of the input `tf.Tensor`. Hence by default, this\n * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s.\n *\n * ```js\n * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);\n *\n * a.transpose().print(); // or tf.transpose(a)\n * ```\n *\n * @param x The tensor to transpose.\n * @param perm The permutation of the dimensions of a.\n */\n/** @doc {heading: 'Operations', subheading: 'Matrices'} */\nfunction transpose_(x: T|TensorLike, perm?: number[]): T {\n const $x = convertToTensor(x, 'x', 'transpose');\n\n if (perm == null) {\n perm = $x.shape.map((s, i) => i).reverse();\n }\n util.assert(\n $x.rank === perm.length,\n () => `Error in transpose: rank of input ${$x.rank} ` +\n `must match length of perm ${perm}.`);\n perm.forEach(axis => {\n util.assert(\n axis >= 0 && axis < $x.rank,\n () => `All entries in 'perm' must be between 0 and ${$x.rank - 1}` +\n ` but got ${perm}`);\n });\n\n if ($x.rank <= 1) {\n return $x.clone();\n }\n\n const attrs = {perm};\n return ENGINE.runKernelFunc(\n backend => backend.transpose($x, perm), {x: $x}, null /* gradient */,\n 'Transpose', attrs);\n}\n\nexport const transpose = op({transpose_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';\nimport {FusedBatchMatMulConfig, FusedConv2DConfig} from '../ops/fused_util';\nimport {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';\nimport {BackendValues, DataType, Rank, ShapeMap} from '../types';\n\nexport const EPSILON_FLOAT32 = 1e-7;\nexport const EPSILON_FLOAT16 = 1e-4;\n\n// Required information for all backends.\nexport interface BackendTimingInfo {\n kernelMs: number|{error: string};\n getExtraProfileInfo?(): string; // a field for additional timing information\n // e.g. packing / unpacking for WebGL backend\n}\n\nexport interface TensorStorage {\n read(dataId: DataId): Promise;\n readSync(dataId: DataId): BackendValues;\n disposeData(dataId: DataId): void;\n write(values: BackendValues, shape: number[], dtype: DataType): DataId;\n move(dataId: DataId, values: BackendValues, shape: number[], dtype: DataType):\n void;\n memory(): {unreliable: boolean;}; // Backend-specific information.\n /** Returns number of data ids currently in the storage. */\n numDataIds(): number;\n}\n\n/** Convenient class for storing tensor-related data. */\nexport class DataStorage {\n private data = new WeakMap();\n private dataIdsCount = 0;\n\n constructor(private backend: KernelBackend, private dataMover: DataMover) {}\n\n get(dataId: DataId) {\n if (!this.data.has(dataId)) {\n this.dataMover.moveData(this.backend, dataId);\n }\n return this.data.get(dataId);\n }\n\n set(dataId: DataId, value: T): void {\n this.dataIdsCount++;\n this.data.set(dataId, value);\n }\n\n has(dataId: DataId): boolean {\n return this.data.has(dataId);\n }\n\n delete(dataId: DataId): boolean {\n this.dataIdsCount--;\n return this.data.delete(dataId);\n }\n\n numDataIds(): number {\n return this.dataIdsCount;\n }\n}\n\nexport interface DataMover {\n /**\n * To be called by backends whenever they see a dataId that they don't own.\n * Upon calling this method, the mover will fetch the tensor from another\n * backend and register it with the current active backend.\n */\n moveData(backend: KernelBackend, dataId: DataId): void;\n}\n\nexport interface BackendTimer {\n time(f: () => void): Promise;\n}\n\n/**\n * The interface that defines the kernels that should be implemented when\n * adding a new backend. New backends don't need to implement every one of the\n * methods, this can be done gradually (throw an error for unimplemented\n * methods).\n */\nexport class KernelBackend implements TensorStorage, Backend, BackendTimer {\n time(f: () => void): Promise {\n return notYetImplemented('time');\n }\n read(dataId: object): Promise {\n return notYetImplemented('read');\n }\n readSync(dataId: object): BackendValues {\n return notYetImplemented('readSync');\n }\n numDataIds(): number {\n return notYetImplemented('numDataIds');\n }\n disposeData(dataId: object): void {\n return notYetImplemented('disposeData');\n }\n write(values: BackendValues, shape: number[], dtype: DataType): DataId {\n return notYetImplemented('write');\n }\n move(dataId: DataId, values: BackendValues, shape: number[], dtype: DataType):\n void {\n return notYetImplemented('move');\n }\n memory(): {unreliable: boolean; reasons?: string[]} {\n return notYetImplemented('memory');\n }\n /** Returns the highest precision for floats in bits (e.g. 16 or 32) */\n floatPrecision(): 16|32 {\n return notYetImplemented('floatPrecision');\n }\n /** Returns the smallest representable number. */\n epsilon(): number {\n return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;\n }\n\n batchMatMul(\n a: Tensor3D, b: Tensor3D, transposeA: boolean,\n transposeB: boolean): Tensor3D {\n return notYetImplemented('batchMatMul');\n }\n\n fusedBatchMatMul(\n {a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:\n FusedBatchMatMulConfig): Tensor3D {\n return notYetImplemented('fusedBatchMatMul');\n }\n\n slice(x: T, begin: number[], size: number[]): T {\n return notYetImplemented('slice');\n }\n stridedSlice(\n x: T, begin: number[], end: number[], strides: number[]): T {\n return notYetImplemented('stridedSlice');\n }\n unstack(x: Tensor, axis: number): Tensor[] {\n return notYetImplemented('unstack');\n }\n reverse(a: T, axis: number[]): T {\n return notYetImplemented('reverse');\n }\n\n concat(tensors: Tensor[], axis: number): Tensor {\n return notYetImplemented('concat');\n }\n\n neg(a: T): T {\n return notYetImplemented('neg');\n }\n\n add(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('add');\n }\n addN(tensors: T[]): T {\n return notYetImplemented('addN');\n }\n subtract(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('subtract');\n }\n multiply(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('multiply');\n }\n realDivide(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('realDivide');\n }\n floorDiv(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('floorDiv');\n }\n\n sum(x: Tensor, axes: number[]): Tensor {\n return notYetImplemented('sum');\n }\n prod(x: Tensor, axes: number[]): Tensor {\n return notYetImplemented('prod');\n }\n\n unsortedSegmentSum(\n x: T, segmentIds: Tensor1D, numSegments: number): Tensor {\n return notYetImplemented('unsortedSegmentSum');\n }\n\n argMin(x: Tensor, axis: number): Tensor {\n return notYetImplemented('argMin');\n }\n argMax(x: Tensor, axis: number): Tensor {\n return notYetImplemented('argMax');\n }\n\n equal(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('equal');\n }\n notEqual(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('notEqual');\n }\n\n less(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('less');\n }\n lessEqual(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('lessEqual');\n }\n\n greater(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('greater');\n }\n greaterEqual(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('greaterEqual');\n }\n\n logicalNot(a: T): T {\n return notYetImplemented('logicalNot');\n }\n logicalAnd(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('logicalAnd');\n }\n logicalOr(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('logicalOr');\n }\n\n where(condition: Tensor): Tensor2D {\n return notYetImplemented('where');\n }\n select(condition: Tensor, a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('select');\n }\n\n topk(x: T, k: number, sorted: boolean): [T, T] {\n return notYetImplemented('topk');\n }\n\n min(x: Tensor, axes: number[]): Tensor {\n return notYetImplemented('min');\n }\n minimum(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('minimum');\n }\n\n mod(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('mod');\n }\n\n max(x: Tensor, axes: number[]): Tensor {\n return notYetImplemented('max');\n }\n maximum(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('maximum');\n }\n\n all(x: Tensor, axes: number[]): Tensor {\n return notYetImplemented('all');\n }\n any(x: Tensor, axes: number[]): Tensor {\n return notYetImplemented('any');\n }\n\n squaredDifference(a: Tensor, b: Tensor): Tensor {\n return notYetImplemented('squaredDifference');\n }\n\n ceil(x: T): T {\n return notYetImplemented('ceil');\n }\n floor(x: T): T {\n return notYetImplemented('floor');\n }\n round(x: T): T {\n return notYetImplemented('round');\n }\n\n sign(x: T): T {\n return notYetImplemented('sign');\n }\n\n isNaN(x: T): T {\n return notYetImplemented('isNaN');\n }\n isInf(x: T): T {\n return notYetImplemented('isInf');\n }\n isFinite(x: T): T {\n return notYetImplemented('isFinite');\n }\n\n pow(a: T, b: Tensor): T {\n return notYetImplemented('pow');\n }\n exp(x: T): T {\n return notYetImplemented('exp');\n }\n expm1(x: T): T {\n return notYetImplemented('expm1');\n }\n softmax(x: T, dim: number): T {\n return notYetImplemented('softmax');\n }\n log(x: T): T {\n return notYetImplemented('log');\n }\n log1p(x: T): T {\n return notYetImplemented('log1p');\n }\n sqrt(x: T): T {\n return notYetImplemented('sqrt');\n }\n rsqrt(x: T): T {\n return notYetImplemented('rsqrt');\n }\n square(x: T): T {\n return notYetImplemented('square');\n }\n reciprocal(x: T): T {\n return notYetImplemented('reciprocal');\n }\n relu(x: T): T {\n return notYetImplemented('relu');\n }\n relu6(x: T): T {\n return notYetImplemented('relu6');\n }\n prelu(x: T, a: T): T {\n return notYetImplemented('prelu');\n }\n elu(x: T): T {\n return notYetImplemented('elu');\n }\n eluDer(dy: T, y: T): T {\n return notYetImplemented('eluDer');\n }\n selu(x: T): T {\n return notYetImplemented('selu');\n }\n int(x: T): T {\n return notYetImplemented('int');\n }\n\n clip(x: T, min: number, max: number): T {\n return notYetImplemented('clip');\n }\n\n abs(x: T): T {\n return notYetImplemented('abs');\n }\n complexAbs(x: T): T {\n return notYetImplemented('complexAbs');\n }\n\n sigmoid(x: T): T {\n return notYetImplemented('sigmoid');\n }\n\n softplus(x: T): T {\n return notYetImplemented('softplus');\n }\n\n sin(x: T): T {\n return notYetImplemented('sin');\n }\n cos(x: T): T {\n return notYetImplemented('cos');\n }\n tan(x: T): T {\n return notYetImplemented('tan');\n }\n\n asin(x: T): T {\n return notYetImplemented('asin');\n }\n acos(x: T): T {\n return notYetImplemented('acos');\n }\n atan(x: T): T {\n return notYetImplemented('atan');\n }\n atan2(a: T, b: T): T {\n return notYetImplemented('atan2');\n }\n\n sinh(x: T): T {\n return notYetImplemented('sinh');\n }\n cosh(x: T): T {\n return notYetImplemented('cosh');\n }\n tanh(x: T): T {\n return notYetImplemented('tanh');\n }\n\n asinh(x: T): T {\n return notYetImplemented('asinh');\n }\n acosh(x: T): T {\n return notYetImplemented('acosh');\n }\n atanh(x: T): T {\n return notYetImplemented('atanh');\n }\n\n erf(x: T): T {\n return notYetImplemented('erf');\n }\n\n step(x: T, alpha: number): T {\n return notYetImplemented('step');\n }\n\n fusedConv2d(\n {input, filter, convInfo, bias, activation, preluActivationWeights}:\n FusedConv2DConfig): Tensor4D {\n return notYetImplemented('fusedConv2d');\n }\n\n conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n return notYetImplemented('conv2d');\n }\n conv2dDerInput(dy: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n return notYetImplemented('conv2dDerInput');\n }\n conv2dDerFilter(x: Tensor4D, dY: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n return notYetImplemented('conv2dDerFilter');\n }\n\n fusedDepthwiseConv2D(\n {input, filter, convInfo, bias, activation, preluActivationWeights}:\n FusedConv2DConfig): Tensor4D {\n return notYetImplemented('fusedDepthwiseConv2D');\n }\n\n depthwiseConv2D(input: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n return notYetImplemented('depthwiseConv2D');\n }\n depthwiseConv2DDerInput(dy: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n return notYetImplemented('depthwiseConv2DDerInput');\n }\n depthwiseConv2DDerFilter(x: Tensor4D, dY: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n return notYetImplemented('depthwiseConv2DDerFilter');\n }\n conv3d(x: Tensor5D, filter: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n return notYetImplemented('conv3d');\n }\n conv3dDerInput(dy: Tensor5D, filter: Tensor5D, convInfo: Conv3DInfo):\n Tensor5D {\n return notYetImplemented('conv3dDerInput');\n }\n conv3dDerFilter(x: Tensor5D, dY: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n return notYetImplemented('conv3dDerFilter');\n }\n maxPool(x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n return notYetImplemented('maxPool');\n }\n maxPoolBackprop(dy: Tensor4D, x: Tensor4D, y: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n return notYetImplemented('maxPoolBackprop');\n }\n avgPool(x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n return notYetImplemented('avgPool');\n }\n avgPoolBackprop(dy: Tensor4D, x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n return notYetImplemented('avgPoolBackprop');\n }\n avgPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n return notYetImplemented('avgPool3d');\n }\n avgPool3dBackprop(dy: Tensor5D, x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n return notYetImplemented('avgPool3dBackprop');\n }\n maxPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n return notYetImplemented('maxPool3d');\n }\n maxPool3dBackprop(\n dy: Tensor5D, x: Tensor5D, y: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n return notYetImplemented('maxPool3dBackprop');\n }\n\n reshape(x: T, shape: ShapeMap[R]):\n Tensor {\n return notYetImplemented('reshape');\n }\n cast(x: T, dtype: DataType): T {\n return notYetImplemented('cast');\n }\n\n tile(x: T, reps: number[]): T {\n return notYetImplemented('tile');\n }\n\n pad(\n x: T, paddings: Array<[number, number]>, constantValue: number): T {\n return notYetImplemented('pad');\n }\n\n transpose(x: T, perm: number[]): T {\n return notYetImplemented('transpose');\n }\n\n gather(x: T, indices: Tensor1D, axis: number): T {\n return notYetImplemented('gather');\n }\n\n gatherND(x: Tensor, indices: Tensor): Tensor {\n return notYetImplemented('gatherND');\n }\n\n scatterND(\n indices: Tensor, updates: Tensor, shape: ShapeMap[R]): Tensor {\n return notYetImplemented('scatterND');\n }\n\n batchToSpaceND(\n x: T, blockShape: number[], crops: number[][]): T {\n return notYetImplemented('batchToSpaceND');\n }\n\n spaceToBatchND(\n x: T, blockShape: number[], paddings: number[][]): T {\n return notYetImplemented('spaceToBatchND');\n }\n\n resizeBilinear(\n x: Tensor4D, newHeight: number, newWidth: number,\n alignCorners: boolean): Tensor4D {\n return notYetImplemented('resizeBilinear');\n }\n\n resizeBilinearBackprop(dy: Tensor4D, x: Tensor4D, alignCorners: boolean):\n Tensor4D {\n return notYetImplemented('resizeBilinearBackprop');\n }\n\n resizeNearestNeighbor(\n x: Tensor4D, newHEight: number, newWidth: number,\n alignCorners: boolean): Tensor4D {\n return notYetImplemented('resizeNearestNeighbor');\n }\n\n resizeNearestNeighborBackprop(\n dy: Tensor4D, x: Tensor4D, alignCorners: boolean): Tensor4D {\n return notYetImplemented('resizeNearestNeighborBackprop');\n }\n\n batchNormalization(\n x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D,\n varianceEpsilon: number, scale?: Tensor4D|Tensor1D,\n offset?: Tensor4D|Tensor1D): Tensor4D {\n return notYetImplemented('batchNormalization');\n }\n\n localResponseNormalization4D(\n x: Tensor4D, radius: number, bias: number, alpha: number,\n beta: number): Tensor4D {\n return notYetImplemented('localResponseNormalization4D');\n }\n\n LRNGrad(\n dy: Tensor4D, inputImage: Tensor4D, outputImage: Tensor4D, radius: number,\n bias: number, alpha: number, beta: number): Tensor4D {\n return notYetImplemented('LRNGrad');\n }\n\n multinomial(\n logits: Tensor2D, normalized: boolean, numSamples: number,\n seed: number): Tensor2D {\n return notYetImplemented('multinomial');\n }\n\n oneHot(indices: Tensor1D, depth: number, onValue: number, offValue: number):\n Tensor2D {\n return notYetImplemented('oneHot');\n }\n\n cumsum(x: Tensor, axis: number, exclusive: boolean, reverse: boolean):\n Tensor {\n return notYetImplemented('cumsum');\n }\n\n nonMaxSuppression(\n boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number,\n iouThreshold: number, scoreThreshold?: number): Tensor1D {\n return notYetImplemented('nonMaxSuppression');\n }\n\n fft(x: Tensor2D): Tensor2D {\n return notYetImplemented('fft');\n }\n ifft(x: Tensor2D): Tensor2D {\n return notYetImplemented('ifft');\n }\n complex(real: T, imag: T): T {\n return notYetImplemented('complex');\n }\n real(input: T): T {\n return notYetImplemented('real');\n }\n imag(input: T): T {\n return notYetImplemented('imag');\n }\n\n cropAndResize(\n image: Tensor4D, boxes: Tensor2D, boxIndex: Tensor1D,\n cropSize: [number, number], method: 'bilinear'|'nearest',\n extrapolationValue: number): Tensor4D {\n return notYetImplemented('cropAndResize');\n }\n\n depthToSpace(x: Tensor4D, blockSize: number, dataFormat: string): Tensor4D {\n return notYetImplemented('depthToSpace');\n }\n\n // Aligns with the \"SplitV\" kernel in TensorFlow.\n split(value: T, sizeSplits: number[], axis: number): T[] {\n return notYetImplemented('split');\n }\n\n sparseToDense(\n sparseIndices: Tensor, sparseValues: Tensor, outputShape: ShapeMap[R],\n defaultValue: Scalar): Tensor {\n return notYetImplemented('sparseToDense');\n }\n\n diag(x: Tensor): Tensor {\n return notYetImplemented('diag');\n }\n\n fill(\n shape: ShapeMap[R], value: number|string, dtype?: DataType): Tensor {\n return notYetImplemented('fill');\n }\n\n onesLike(x: Tensor): Tensor {\n return notYetImplemented('onesLike');\n }\n\n zerosLike(x: Tensor): Tensor {\n return notYetImplemented('zerosLike');\n }\n\n linspace(start: number, stop: number, num: number): Tensor1D {\n return notYetImplemented('linspace');\n }\n\n dispose(): void {\n return notYetImplemented('dispose');\n }\n}\n\nfunction notYetImplemented(kernelName: string): never {\n throw new Error(\n `'${kernelName}' not yet implemented or not found in the registry. ` +\n `Did you forget to import the kernel?`);\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as util from '../util';\n\ntype PadType = 'SAME'|'VALID'|'NUMBER';\n\nexport type PadInfo = {\n top: number,\n left: number,\n right: number,\n bottom: number,\n type: PadType\n};\n\nexport type PadInfo3D = {\n top: number,\n left: number,\n right: number,\n bottom: number,\n front: number,\n back: number,\n type: PadType\n};\n\n/**\n * Information about the forward pass of a convolution/pooling operation.\n * It includes input and output shape, strides, filter size and padding\n * information.\n */\nexport type Conv2DInfo = {\n batchSize: number,\n inHeight: number,\n inWidth: number,\n inChannels: number,\n outHeight: number,\n outWidth: number,\n outChannels: number,\n dataFormat: 'channelsFirst'|'channelsLast',\n strideHeight: number,\n strideWidth: number,\n dilationHeight: number,\n dilationWidth: number,\n filterHeight: number,\n filterWidth: number,\n effectiveFilterHeight: number,\n effectiveFilterWidth: number,\n padInfo: PadInfo,\n inShape: [number, number, number, number],\n outShape: [number, number, number, number],\n filterShape: [number, number, number, number]\n};\n\nexport function computePool2DInfo(\n inShape: [number, number, number, number],\n filterSize: [number, number]|number, strides: number|[number, number],\n dilations: number|[number, number], pad: 'same'|'valid'|number,\n roundingMode?: 'floor'|'round'|'ceil',\n dataFormat: 'channelsFirst'|'channelsLast' = 'channelsLast'): Conv2DInfo {\n const [filterHeight, filterWidth] = parseTupleParam(filterSize);\n\n let filterShape: [number, number, number, number];\n if (dataFormat === 'channelsLast') {\n filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]];\n } else if (dataFormat === 'channelsFirst') {\n filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]];\n } else {\n throw new Error(`Unknown dataFormat ${dataFormat}`);\n }\n\n return computeConv2DInfo(\n inShape, filterShape, strides, dilations, pad, roundingMode, false,\n dataFormat);\n}\n\n/**\n * Computes the information for a forward pass of a pooling3D operation.\n */\nexport function computePool3DInfo(\n inShape: [number, number, number, number, number],\n filterSize: number|[number, number, number],\n strides: number|[number, number, number],\n dilations: number|[number, number, number], pad: 'same'|'valid'|number,\n roundingMode?: 'floor'|'round'|'ceil',\n dataFormat: 'NDHWC'|'NCDHW' = 'NDHWC'): Conv3DInfo {\n const [filterDepth, filterHeight, filterWidth] = parse3TupleParam(filterSize);\n\n let filterShape: [number, number, number, number, number];\n let $dataFormat: 'channelsFirst'|'channelsLast';\n if (dataFormat === 'NDHWC') {\n $dataFormat = 'channelsLast';\n filterShape =\n [filterDepth, filterHeight, filterWidth, inShape[4], inShape[4]];\n } else if (dataFormat === 'NCDHW') {\n $dataFormat = 'channelsFirst';\n filterShape =\n [filterDepth, filterHeight, filterWidth, inShape[1], inShape[1]];\n } else {\n throw new Error(`Unknown dataFormat ${dataFormat}`);\n }\n\n return computeConv3DInfo(\n inShape, filterShape, strides, dilations, pad, false, $dataFormat,\n roundingMode);\n}\n\n/**\n * Computes the information for a forward pass of a convolution/pooling\n * operation.\n */\nexport function computeConv2DInfo(\n inShape: [number, number, number, number],\n filterShape: [number, number, number, number],\n strides: number|[number, number], dilations: number|[number, number],\n pad: 'same'|'valid'|number, roundingMode?: 'floor'|'round'|'ceil',\n depthwise = false,\n dataFormat: 'channelsFirst'|'channelsLast' = 'channelsLast'): Conv2DInfo {\n let [batchSize, inHeight, inWidth, inChannels] = [-1, -1, -1, -1];\n if (dataFormat === 'channelsLast') {\n [batchSize, inHeight, inWidth, inChannels] = inShape;\n } else if (dataFormat === 'channelsFirst') {\n [batchSize, inChannels, inHeight, inWidth] = inShape;\n } else {\n throw new Error(`Unknown dataFormat ${dataFormat}`);\n }\n\n const [filterHeight, filterWidth, , filterChannels] = filterShape;\n const [strideHeight, strideWidth] = parseTupleParam(strides);\n const [dilationHeight, dilationWidth] = parseTupleParam(dilations);\n\n const effectiveFilterHeight =\n getEffectiveFilterSize(filterHeight, dilationHeight);\n const effectiveFilterWidth =\n getEffectiveFilterSize(filterWidth, dilationWidth);\n const {padInfo, outHeight, outWidth} = getPadAndOutInfo(\n pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight,\n effectiveFilterWidth, roundingMode);\n\n const outChannels = depthwise ? filterChannels * inChannels : filterChannels;\n\n let outShape: [number, number, number, number];\n if (dataFormat === 'channelsFirst') {\n outShape = [batchSize, outChannels, outHeight, outWidth];\n } else if (dataFormat === 'channelsLast') {\n outShape = [batchSize, outHeight, outWidth, outChannels];\n }\n\n return {\n batchSize,\n dataFormat,\n inHeight,\n inWidth,\n inChannels,\n outHeight,\n outWidth,\n outChannels,\n padInfo,\n strideHeight,\n strideWidth,\n filterHeight,\n filterWidth,\n effectiveFilterHeight,\n effectiveFilterWidth,\n dilationHeight,\n dilationWidth,\n inShape,\n outShape,\n filterShape\n };\n}\n\n/**\n * Information about the forward pass of a 3D convolution/pooling operation.\n * It includes input and output shape, strides, filter size and padding\n * information.\n */\nexport type Conv3DInfo = {\n batchSize: number,\n inDepth: number,\n inHeight: number,\n inWidth: number,\n inChannels: number,\n outDepth: number,\n outHeight: number,\n outWidth: number,\n outChannels: number,\n dataFormat: 'channelsFirst'|'channelsLast',\n strideDepth: number,\n strideHeight: number,\n strideWidth: number,\n dilationDepth: number,\n dilationHeight: number,\n dilationWidth: number,\n filterDepth: number,\n filterHeight: number,\n filterWidth: number,\n effectiveFilterDepth: number,\n effectiveFilterHeight: number,\n effectiveFilterWidth: number,\n padInfo: PadInfo3D,\n inShape: [number, number, number, number, number],\n outShape: [number, number, number, number, number],\n filterShape: [number, number, number, number, number]\n};\n\n/**\n * Computes the information for a forward pass of a 3D convolution/pooling\n * operation.\n */\nexport function computeConv3DInfo(\n inShape: [number, number, number, number, number],\n filterShape: [number, number, number, number, number],\n strides: number|[number, number, number],\n dilations: number|[number, number, number], pad: 'same'|'valid'|number,\n depthwise = false,\n dataFormat: 'channelsFirst'|'channelsLast' = 'channelsLast',\n roundingMode?: 'floor'|'round'|'ceil'): Conv3DInfo {\n let [batchSize, inDepth, inHeight, inWidth, inChannels] =\n [-1, -1, -1, -1, -1];\n if (dataFormat === 'channelsLast') {\n [batchSize, inDepth, inHeight, inWidth, inChannels] = inShape;\n } else if (dataFormat === 'channelsFirst') {\n [batchSize, inChannels, inDepth, inHeight, inWidth] = inShape;\n } else {\n throw new Error(`Unknown dataFormat ${dataFormat}`);\n }\n\n const [filterDepth, filterHeight, filterWidth, , filterChannels] =\n filterShape;\n const [strideDepth, strideHeight, strideWidth] = parse3TupleParam(strides);\n const [dilationDepth, dilationHeight, dilationWidth] =\n parse3TupleParam(dilations);\n\n const effectiveFilterDepth =\n getEffectiveFilterSize(filterDepth, dilationDepth);\n const effectiveFilterHeight =\n getEffectiveFilterSize(filterHeight, dilationHeight);\n const effectiveFilterWidth =\n getEffectiveFilterSize(filterWidth, dilationWidth);\n const {padInfo, outDepth, outHeight, outWidth} = get3DPadAndOutInfo(\n pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth,\n effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth,\n roundingMode);\n\n const outChannels = depthwise ? filterChannels * inChannels : filterChannels;\n\n let outShape: [number, number, number, number, number];\n if (dataFormat === 'channelsFirst') {\n outShape = [batchSize, outChannels, outDepth, outHeight, outWidth];\n } else if (dataFormat === 'channelsLast') {\n outShape = [batchSize, outDepth, outHeight, outWidth, outChannels];\n }\n\n return {\n batchSize,\n dataFormat,\n inDepth,\n inHeight,\n inWidth,\n inChannels,\n outDepth,\n outHeight,\n outWidth,\n outChannels,\n padInfo,\n strideDepth,\n strideHeight,\n strideWidth,\n filterDepth,\n filterHeight,\n filterWidth,\n effectiveFilterDepth,\n effectiveFilterHeight,\n effectiveFilterWidth,\n dilationDepth,\n dilationHeight,\n dilationWidth,\n inShape,\n outShape,\n filterShape\n };\n}\n\nfunction computeOutputShape2D(\n inShape: [number, number], fieldSize: number, stride: number,\n zeroPad?: number, roundingMode?: 'floor'|'round'|'ceil'): [number, number] {\n if (zeroPad == null) {\n zeroPad = computeDefaultPad(inShape, fieldSize, stride);\n }\n const inputRows = inShape[0];\n const inputCols = inShape[1];\n\n const outputRows = conditionalRound(\n (inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);\n util.assert(\n util.isInt(outputRows),\n () => `The output # of rows (${outputRows}) must be an integer. ` +\n `Change the stride and/or zero pad parameters`);\n\n const outputCols = conditionalRound(\n (inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);\n util.assert(\n util.isInt(outputCols),\n () => `The output # of columns (${outputCols}) must be an integer. ` +\n `Change the stride and/or zero pad parameters`);\n\n return [outputRows, outputCols];\n}\n\nfunction computeOutputShape4D(\n inShape: [number, number, number, number], fieldSize: number,\n outChannels: number, stride: number, zeroPad?: number,\n roundingMode?: 'floor'|'round'|'ceil'): [number, number, number, number] {\n if (zeroPad == null) {\n zeroPad = computeDefaultPad(inShape, fieldSize, stride);\n }\n const inputDepth = inShape[0];\n const inputRows = inShape[1];\n const inputCols = inShape[2];\n\n const outputDepths = conditionalRound(\n (inputDepth - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);\n util.assert(\n util.isInt(outputDepths),\n () => `The output # of depths (${outputDepths}) must be an integer. ` +\n `Change the stride and/or zero pad parameters`);\n\n const outputRows = conditionalRound(\n (inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);\n util.assert(\n util.isInt(outputRows),\n () => `The output # of rows (${outputRows}) must be an integer. ` +\n `Change the stride and/or zero pad parameters`);\n\n const outputCols = conditionalRound(\n (inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);\n util.assert(\n util.isInt(outputCols),\n () => `The output # of columns (${outputCols}) must be an integer. ` +\n `Change the stride and/or zero pad parameters`);\n\n return [outputDepths, outputRows, outputCols, outChannels];\n}\n\nexport function computeDefaultPad(\n inputShape: [number, number]|[number, number, number, number],\n fieldSize: number, stride: number, dilation = 1): number {\n const effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);\n return Math.floor(\n (inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2);\n}\n\nfunction parseTupleParam(param: number|number[]): [number, number, number] {\n if (typeof param === 'number') {\n return [param, param, param];\n }\n if (param.length === 2) {\n return [param[0], param[1], 1];\n }\n return param as [number, number, number];\n}\n\nfunction parse3TupleParam(param: number|[number, number, number]):\n [number, number, number] {\n return typeof param === 'number' ? [param, param, param] : param;\n}\n\n/* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d\n * Atrous convolution is equivalent to standard convolution with upsampled\n * filters with effective_filter_height =\n * filter_height + (filter_height - 1) * (dilation - 1)\n * and effective_filter_width =\n * filter_width + (filter_width - 1) * (dilation - 1),\n * produced by inserting dilation - 1 zeros along consecutive elements across\n * the filters' spatial dimensions.\n * When there is a dilation, this converts a filter dimension to the\n * effective filter dimension, so it can be used in a standard convolution.\n */\nfunction getEffectiveFilterSize(filterSize: number, dilation: number) {\n if (dilation <= 1) {\n return filterSize;\n }\n\n return filterSize + (filterSize - 1) * (dilation - 1);\n}\n\nfunction getPadAndOutInfo(\n pad: 'same'|'valid'|number, inHeight: number, inWidth: number,\n strideHeight: number, strideWidth: number, filterHeight: number,\n filterWidth: number, roundingMode?: 'floor'|'round'|'ceil'):\n {padInfo: PadInfo, outHeight: number, outWidth: number} {\n let padInfo: PadInfo;\n let outHeight: number;\n let outWidth: number;\n\n if (typeof pad === 'number') {\n const padType = (pad === 0) ? 'VALID' : 'NUMBER';\n padInfo = {top: pad, bottom: pad, left: pad, right: pad, type: padType};\n const outShape = computeOutputShape2D(\n [inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode);\n outHeight = outShape[0];\n outWidth = outShape[1];\n } else if (pad === 'same') {\n outHeight = Math.ceil(inHeight / strideHeight);\n outWidth = Math.ceil(inWidth / strideWidth);\n const padAlongHeight =\n Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);\n const padAlongWidth =\n Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);\n const top = Math.floor(padAlongHeight / 2);\n const bottom = padAlongHeight - top;\n const left = Math.floor(padAlongWidth / 2);\n const right = padAlongWidth - left;\n padInfo = {top, bottom, left, right, type: 'SAME'};\n } else if (pad === 'valid') {\n padInfo = {top: 0, bottom: 0, left: 0, right: 0, type: 'VALID'};\n outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);\n outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);\n } else {\n throw Error(`Unknown padding parameter: ${pad}`);\n }\n return {padInfo, outHeight, outWidth};\n}\n\nfunction get3DPadAndOutInfo(\n pad: 'same'|'valid'|number, inDepth: number, inHeight: number,\n inWidth: number, strideDepth: number, strideHeight: number,\n strideWidth: number, filterDepth: number, filterHeight: number,\n filterWidth: number, roundingMode?: 'floor'|'round'|'ceil'): {\n padInfo: PadInfo3D,\n outDepth: number,\n outHeight: number,\n outWidth: number\n} {\n let padInfo: PadInfo3D;\n let outDepth: number;\n let outHeight: number;\n let outWidth: number;\n\n if (typeof pad === 'number') {\n const padType = (pad === 0) ? 'VALID' : 'NUMBER';\n padInfo = {\n top: pad,\n bottom: pad,\n left: pad,\n right: pad,\n front: pad,\n back: pad,\n type: padType\n };\n const outShape = computeOutputShape4D(\n [inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad,\n roundingMode);\n outDepth = outShape[0];\n outHeight = outShape[1];\n outWidth = outShape[2];\n } else if (pad === 'same') {\n outDepth = Math.ceil(inDepth / strideDepth);\n outHeight = Math.ceil(inHeight / strideHeight);\n outWidth = Math.ceil(inWidth / strideWidth);\n const padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth;\n const padAlongHeight =\n (outHeight - 1) * strideHeight + filterHeight - inHeight;\n const padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth;\n const front = Math.floor(padAlongDepth / 2);\n const back = padAlongDepth - front;\n const top = Math.floor(padAlongHeight / 2);\n const bottom = padAlongHeight - top;\n const left = Math.floor(padAlongWidth / 2);\n const right = padAlongWidth - left;\n\n padInfo = {top, bottom, left, right, front, back, type: 'SAME'};\n } else if (pad === 'valid') {\n padInfo = {\n top: 0,\n bottom: 0,\n left: 0,\n right: 0,\n front: 0,\n back: 0,\n type: 'VALID'\n };\n outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth);\n outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);\n outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);\n } else {\n throw Error(`Unknown padding parameter: ${pad}`);\n }\n return {padInfo, outDepth, outHeight, outWidth};\n}\n\n/**\n * Rounds a value depending on the rounding mode\n * @param value\n * @param roundingMode\n */\nfunction conditionalRound(\n value: number, roundingMode?: 'floor'|'round'|'ceil') {\n if (!roundingMode) {\n return value;\n }\n switch (roundingMode) {\n case 'round':\n // used for Caffe Conv\n return Math.round(value);\n case 'ceil':\n // used for Caffe Pool\n return Math.ceil(value);\n case 'floor':\n return Math.floor(value);\n default:\n throw new Error(`Unknown roundingMode ${roundingMode}`);\n }\n}\n\nexport function tupleValuesAreOne(param: number|number[]): boolean {\n const [dimA, dimB, dimC] = parseTupleParam(param);\n return dimA === 1 && dimB === 1 && dimC === 1;\n}\n\nexport function eitherStridesOrDilationsAreOne(\n strides: number|number[], dilations: number|number[]): boolean {\n return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);\n}\n\n/**\n * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to\n * 'channelsLast'|'channelsFirst'\n * @param dataFormat in 'NHWC'|'NCHW' mode\n * @return dataFormat in 'channelsLast'|'channelsFirst' mode\n * @throws unknown dataFormat\n */\nexport function convertConv2DDataFormat(dataFormat: 'NHWC'|'NCHW'):\n 'channelsLast'|'channelsFirst' {\n if (dataFormat === 'NHWC') {\n return 'channelsLast';\n } else if (dataFormat === 'NCHW') {\n return 'channelsFirst';\n } else {\n throw new Error(`Unknown dataFormat ${dataFormat}`);\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {scalar, tensor1d, zeros} from '../ops/tensor_ops';\nimport {Tensor} from '../tensor';\nimport {Rank} from '../types';\nimport {DataType, ShapeMap} from '../types';\nimport {hasEncodingLoss, makeZerosTypedArray} from '../util';\n\nimport {KernelBackend} from './backend';\n\n// Utilities needed by backend consumers of tf-core.\nexport * from '../ops/axis_util';\nexport * from '../ops/broadcast_util';\nexport * from '../ops/concat_util';\nexport * from '../ops/conv_util';\nexport {Activation, FusedConv2DConfig} from '../ops/fused_util';\nexport * from '../ops/reduce_util';\nexport {BackendValues, TypedArray, upcastType, PixelData} from '../types';\nexport {MemoryInfo, TimingInfo} from '../engine';\n\nexport function castTensor(\n x: T, dtype: DataType, backend: KernelBackend): T {\n if (dtype === 'complex64') {\n if (x.dtype === 'complex64') {\n return x.clone();\n }\n const zerosTensor = zeros(x.shape);\n const floatX = x.toFloat();\n const result = backend.complex(floatX, zerosTensor);\n zerosTensor.dispose();\n floatX.dispose();\n return result as T;\n }\n\n if (!hasEncodingLoss(x.dtype, dtype)) {\n // We don't change the underlying data, since we cast to higher\n // precision.\n return ENGINE.makeTensorFromDataId(x.dataId, x.shape, dtype) as T;\n }\n if (x.dtype === 'complex64') {\n const real = backend.real(x);\n const result = real.cast(dtype);\n real.dispose();\n return result;\n }\n if (dtype === 'int32') {\n return backend.int(x);\n } else if (dtype === 'bool') {\n const zero = scalar(0, x.dtype);\n const result = backend.notEqual(x, zero) as T;\n zero.dispose();\n return result;\n } else {\n throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);\n }\n}\n\nexport function reshapeTensor(\n x: T, shape: ShapeMap[R]): Tensor {\n return ENGINE.makeTensorFromDataId(x.dataId, shape, x.dtype) as Tensor;\n}\n\nexport function linspaceImpl(start: number, stop: number, num: number) {\n const step = (stop - start) / (num - 1);\n\n const values = makeZerosTypedArray(num, 'float32');\n values[0] = start;\n for (let i = 1; i < values.length; i++) {\n values[i] = values[i - 1] + step;\n }\n\n return tensor1d(values, 'float32');\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {TypedArray} from '../types';\n/**\n * Merges real and imaginary Float32Arrays into a single complex Float32Array.\n *\n * The memory layout is interleaved as follows:\n * real: [r0, r1, r2]\n * imag: [i0, i1, i2]\n * complex: [r0, i0, r1, i1, r2, i2]\n *\n * This is the inverse of splitRealAndImagArrays.\n *\n * @param real The real values of the complex tensor values.\n * @param imag The imag values of the complex tensor values.\n * @returns A complex tensor as a Float32Array with merged values.\n */\nexport function mergeRealAndImagArrays(\n real: Float32Array, imag: Float32Array): Float32Array {\n if (real.length !== imag.length) {\n throw new Error(\n `Cannot merge real and imag arrays of different lengths. real:` +\n `${real.length}, imag: ${imag.length}.`);\n }\n const result = new Float32Array(real.length * 2);\n for (let i = 0; i < result.length; i += 2) {\n result[i] = real[i / 2];\n result[i + 1] = imag[i / 2];\n }\n return result;\n}\n\n/**\n * Splits a complex Float32Array into real and imag parts.\n *\n * The memory layout is interleaved as follows:\n * complex: [r0, i0, r1, i1, r2, i2]\n * real: [r0, r1, r2]\n * imag: [i0, i1, i2]\n *\n * This is the inverse of mergeRealAndImagArrays.\n *\n * @param complex The complex tensor values.\n * @returns An object with real and imag Float32Array components of the complex\n * tensor.\n */\nexport function splitRealAndImagArrays(complex: Float32Array):\n {real: Float32Array, imag: Float32Array} {\n const real = new Float32Array(complex.length / 2);\n const imag = new Float32Array(complex.length / 2);\n for (let i = 0; i < complex.length; i += 2) {\n real[i / 2] = complex[i];\n imag[i / 2] = complex[i + 1];\n }\n return {real, imag};\n}\n\n/**\n * Extracts even indexed complex values in the given array.\n * @param complex The complex tensor values\n */\nexport function complexWithEvenIndex(complex: Float32Array):\n {real: Float32Array, imag: Float32Array} {\n const len = Math.ceil(complex.length / 4);\n const real = new Float32Array(len);\n const imag = new Float32Array(len);\n for (let i = 0; i < complex.length; i += 4) {\n real[Math.floor(i / 4)] = complex[i];\n imag[Math.floor(i / 4)] = complex[i + 1];\n }\n return {real, imag};\n}\n\n/**\n * Extracts odd indexed comple values in the given array.\n * @param complex The complex tensor values\n */\nexport function complexWithOddIndex(complex: Float32Array):\n {real: Float32Array, imag: Float32Array} {\n const len = Math.floor(complex.length / 4);\n const real = new Float32Array(len);\n const imag = new Float32Array(len);\n for (let i = 2; i < complex.length; i += 4) {\n real[Math.floor(i / 4)] = complex[i];\n imag[Math.floor(i / 4)] = complex[i + 1];\n }\n return {real, imag};\n}\n\n/**\n * Get the map representing a complex value in the given array.\n * @param complex The complex tensor values.\n * @param index An index of the target complex value.\n */\nexport function getComplexWithIndex(\n complex: Float32Array, index: number): {real: number, imag: number} {\n const real = complex[index * 2];\n const imag = complex[index * 2 + 1];\n return {real, imag};\n}\n\n/**\n * Insert a given complex value into the TypedArray.\n * @param data The array in which the complex value is inserted.\n * @param c The complex value to be inserted.\n * @param index An index of the target complex value.\n */\nexport function assignToTypedArray(\n data: TypedArray, real: number, imag: number, index: number) {\n data[index * 2] = real;\n data[index * 2 + 1] = imag;\n}\n\n/**\n * Make the list of exponent terms used by FFT.\n */\nexport function exponents(\n n: number, inverse: boolean): {real: Float32Array, imag: Float32Array} {\n const real = new Float32Array(n / 2);\n const imag = new Float32Array(n / 2);\n for (let i = 0; i < Math.ceil(n / 2); i++) {\n const x = (inverse ? 2 : -2) * Math.PI * (i / n);\n real[i] = Math.cos(x);\n imag[i] = Math.sin(x);\n }\n return {real, imag};\n}\n\n/**\n * Make the exponent term used by FFT.\n */\nexport function exponent(\n k: number, n: number, inverse: boolean): {real: number, imag: number} {\n const x = (inverse ? 2 : -2) * Math.PI * (k / n);\n const real = Math.cos(x);\n const imag = Math.sin(x);\n return {real, imag};\n}\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * Inserts a value into a sorted array. This method allows duplicate, meaning it\n * allows inserting duplicate value, in which case, the element will be inserted\n * at the lowest index of the value.\n * @param arr The array to modify.\n * @param element The element to insert.\n * @param comparator Optional. If no comparator is specified, elements are\n * compared using array_util.defaultComparator, which is suitable for Strings\n * and Numbers in ascending arrays. If the array contains multiple instances of\n * the target value, the left-most instance will be returned. To provide a\n * comparator, it should take 2 arguments to compare and return a negative,\n * zero, or a positive number.\n */\nexport function binaryInsert(\n arr: T[], element: T, comparator?: (a: T, b: T) => number) {\n const index = binarySearch(arr, element, comparator);\n const insertionPoint = index < 0 ? -(index + 1) : index;\n arr.splice(insertionPoint, 0, element);\n}\n\n/**\n * Searches the array for the target using binary search, returns the index\n * of the found element, or position to insert if element not found. If no\n * comparator is specified, elements are compared using array_\n * util.defaultComparator, which is suitable for Strings and Numbers in\n * ascending arrays. If the array contains multiple instances of the target\n * value, the left-most instance will be returned.\n * @param arr The array to be searched in.\n * @param target The target to be searched for.\n * @param comparator Should take 2 arguments to compare and return a negative,\n * zero, or a positive number.\n * @return Lowest index of the target value if found, otherwise the insertion\n * point where the target should be inserted, in the form of\n * (-insertionPoint - 1).\n */\nexport function binarySearch(\n arr: T[], target: T, comparator?: (a: T, b: T) => number) {\n return binarySearch_(arr, target, comparator || defaultComparator);\n}\n\n/**\n * Compares its two arguments for order.\n * @param a The first element to be compared.\n * @param b The second element to be compared.\n * @return A negative number, zero, or a positive number as the first\n * argument is less than, equal to, or greater than the second.\n */\nfunction defaultComparator(a: T, b: T): number {\n return a > b ? 1 : a < b ? -1 : 0;\n}\n\nfunction binarySearch_(\n arr: T[], target: T, comparator: (a: T, b: T) => number) {\n let left = 0;\n let right = arr.length;\n let middle = 0;\n let found = false;\n while (left < right) {\n middle = left + ((right - left) >>> 1);\n const compareResult = comparator(target, arr[middle]);\n if (compareResult > 0) {\n left = middle + 1;\n } else {\n right = middle;\n // If compareResult is 0, the value is found. We record it is found,\n // and then keep looking because there may be duplicate.\n found = !compareResult;\n }\n }\n\n return found ? left : -left - 1;\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * Implementation of the NonMaxSuppression kernel shared between webgl and cpu.\n */\n\nimport {scalar, tensor1d} from '../ops/tensor_ops';\nimport {Tensor1D} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {TypedArray} from '../types';\n\nimport {binaryInsert} from './array_util';\n\ninterface Candidate {\n score: number;\n boxIndex: number;\n suppressBeginIndex: number;\n}\n\nexport function nonMaxSuppressionV3(\n boxes: TypedArray, scores: TypedArray, maxOutputSize: number,\n iouThreshold: number, scoreThreshold: number): Tensor1D {\n const dummySoftNmsSigma = 0.0;\n\n return nonMaxSuppressionImpl_(\n boxes, scores, maxOutputSize, iouThreshold, scoreThreshold,\n dummySoftNmsSigma)\n .selectedIndices as Tensor1D;\n}\n\nexport function nonMaxSuppressionV5(\n boxes: TypedArray, scores: TypedArray, maxOutputSize: number,\n iouThreshold: number, scoreThreshold: number,\n softNmsSigma: number): NamedTensorMap {\n // For NonMaxSuppressionV5Op, we always return a second output holding\n // corresponding scores.\n const returnScoresTensor = true;\n\n const result = nonMaxSuppressionImpl_(\n boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma,\n returnScoresTensor);\n\n result.numValidOutputs.dispose();\n\n return {\n selectedIndices: result.selectedIndices,\n selectedScores: result.selectedScores\n };\n}\n\nfunction nonMaxSuppressionImpl_(\n boxes: TypedArray, scores: TypedArray, maxOutputSize: number,\n iouThreshold: number, scoreThreshold: number, softNmsSigma: number,\n returnScoresTensor = false, padToMaxOutputSize = false): NamedTensorMap {\n // The list is sorted in ascending order, so that we can always pop the\n // candidate with the largest score in O(1) time.\n const candidates =\n Array.from(scores)\n .map((score, boxIndex) => ({score, boxIndex, suppressBeginIndex: 0}))\n .filter(c => c.score > scoreThreshold)\n .sort(ascendingComparator);\n\n // If softNmsSigma is 0, the outcome of this algorithm is exactly same as\n // before.\n const scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0;\n\n const selectedIndices: number[] = [];\n const selectedScores: number[] = [];\n\n while (selectedIndices.length < maxOutputSize && candidates.length > 0) {\n const candidate = candidates.pop();\n const {score: originalScore, boxIndex, suppressBeginIndex} = candidate;\n\n if (originalScore < scoreThreshold) {\n break;\n }\n\n // Overlapping boxes are likely to have similar scores, therefore we\n // iterate through the previously selected boxes backwards in order to\n // see if candidate's score should be suppressed. We use\n // suppressBeginIndex to track and ensure a candidate can be suppressed\n // by a selected box no more than once. Also, if the overlap exceeds\n // iouThreshold, we simply ignore the candidate.\n let ignoreCandidate = false;\n for (let j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) {\n const iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]);\n\n if (iou >= iouThreshold) {\n ignoreCandidate = true;\n break;\n }\n\n candidate.score =\n candidate.score * suppressWeight(iouThreshold, scale, iou);\n\n if (candidate.score <= scoreThreshold) {\n break;\n }\n }\n\n // At this point, if `candidate.score` has not dropped below\n // `scoreThreshold`, then we know that we went through all of the\n // previous selections and can safely update `suppressBeginIndex` to the\n // end of the selected array. Then we can re-insert the candidate with\n // the updated score and suppressBeginIndex back in the candidate list.\n // If on the other hand, `candidate.score` has dropped below the score\n // threshold, we will not add it back to the candidates list.\n candidate.suppressBeginIndex = selectedIndices.length;\n\n if (!ignoreCandidate) {\n // Candidate has passed all the tests, and is not suppressed, so\n // select the candidate.\n if (candidate.score === originalScore) {\n selectedIndices.push(boxIndex);\n selectedScores.push(candidate.score);\n } else if (candidate.score > scoreThreshold) {\n // Candidate's score is suppressed but is still high enough to be\n // considered, so add back to the candidates list.\n binaryInsert(candidates, candidate, ascendingComparator);\n }\n }\n }\n\n // NonMaxSuppressionV4 feature: padding output to maxOutputSize.\n const numValidOutputs = selectedIndices.length;\n if (padToMaxOutputSize) {\n selectedIndices.fill(0, numValidOutputs);\n selectedScores.fill(0.0, numValidOutputs);\n }\n\n return {\n selectedIndices: tensor1d(selectedIndices, 'int32'),\n selectedScores: tensor1d(selectedScores, 'float32'),\n numValidOutputs: scalar(numValidOutputs, 'int32')\n };\n}\n\nfunction intersectionOverUnion(boxes: TypedArray, i: number, j: number) {\n const iCoord = boxes.subarray(i * 4, i * 4 + 4);\n const jCoord = boxes.subarray(j * 4, j * 4 + 4);\n const yminI = Math.min(iCoord[0], iCoord[2]);\n const xminI = Math.min(iCoord[1], iCoord[3]);\n const ymaxI = Math.max(iCoord[0], iCoord[2]);\n const xmaxI = Math.max(iCoord[1], iCoord[3]);\n const yminJ = Math.min(jCoord[0], jCoord[2]);\n const xminJ = Math.min(jCoord[1], jCoord[3]);\n const ymaxJ = Math.max(jCoord[0], jCoord[2]);\n const xmaxJ = Math.max(jCoord[1], jCoord[3]);\n const areaI = (ymaxI - yminI) * (xmaxI - xminI);\n const areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);\n if (areaI <= 0 || areaJ <= 0) {\n return 0.0;\n }\n const intersectionYmin = Math.max(yminI, yminJ);\n const intersectionXmin = Math.max(xminI, xminJ);\n const intersectionYmax = Math.min(ymaxI, ymaxJ);\n const intersectionXmax = Math.min(xmaxI, xmaxJ);\n const intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) *\n Math.max(intersectionXmax - intersectionXmin, 0.0);\n return intersectionArea / (areaI + areaJ - intersectionArea);\n}\n\n// A Gaussian penalty function, this method always returns values in [0, 1].\n// The weight is a function of similarity, the more overlap two boxes are, the\n// smaller the weight is, meaning highly overlapping boxe will be significantly\n// penalized. On the other hand, a non-overlapping box will not be penalized.\nfunction suppressWeight(iouThreshold: number, scale: number, iou: number) {\n const weight = Math.exp(scale * iou * iou);\n return iou <= iouThreshold ? weight : 0.0;\n}\n\nfunction ascendingComparator(c1: Candidate, c2: Candidate) {\n // For objects with same scores, we make the object with the larger index go\n // first. In an array that pops from the end, this means that the object with\n // the smaller index will be popped first. This ensures the same output as\n // the TensorFlow python version.\n return (c1.score - c2.score) ||\n ((c1.score === c2.score) && (c2.boxIndex - c1.boxIndex));\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\n\n/** Shared implementation of the split kernel across WebGL and CPU. */\nexport function split(\n x: T, sizeSplits: number[], axis: number): T[] {\n const begin = new Array(x.rank).fill(0);\n const size = x.shape.slice();\n return sizeSplits.map(s => {\n size[axis] = s;\n const slice = x.slice(begin, size);\n begin[axis] += s;\n return slice;\n });\n}\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * An implementation of the tile kernel shared between webgl and cpu for string\n * tensors only.\n */\n\nimport {buffer} from '../ops/array_ops';\nimport {Tensor, TensorBuffer} from '../tensor';\nimport {DataType, Rank} from '../types';\n\nexport function tile(\n xBuf: TensorBuffer, reps: number[]): Tensor {\n const newShape: number[] = new Array(xBuf.rank);\n for (let i = 0; i < newShape.length; i++) {\n newShape[i] = xBuf.shape[i] * reps[i];\n }\n const result = buffer(newShape, xBuf.dtype);\n for (let i = 0; i < result.values.length; ++i) {\n const newLoc = result.indexToLoc(i);\n\n const originalLoc: number[] = new Array(xBuf.rank);\n for (let j = 0; j < originalLoc.length; j++) {\n originalLoc[j] = newLoc[j] % xBuf.shape[j];\n }\n\n const originalIndex = xBuf.locToIndex(originalLoc);\n\n result.values[i] = xBuf.values[originalIndex];\n }\n return result.toTensor() as Tensor;\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/** An implementation of the TopK kernel shared between webgl and cpu. */\n\nimport {tensor} from '../ops/tensor_ops';\nimport {Tensor} from '../tensor';\nimport {NumericDataType, TypedArray} from '../types';\nimport {getTypedArrayFromDType} from '../util';\n\nexport function topkImpl(\n x: TypedArray, xShape: number[], xDtype: NumericDataType, k: number,\n sorted: boolean): [T, T] {\n // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.\n const lastDim = xShape[xShape.length - 1];\n const [batch, size] = [x.length / lastDim, lastDim];\n const allTopKVals = getTypedArrayFromDType(xDtype, batch * k);\n const allTopKIndices = getTypedArrayFromDType('int32', batch * k);\n\n for (let b = 0; b < batch; b++) {\n const offset = b * size;\n const vals = x.subarray(offset, offset + size);\n const valAndInd: Array<{value: number, index: number}> = [];\n for (let i = 0; i < vals.length; i++) {\n valAndInd.push({value: vals[i], index: i});\n }\n valAndInd.sort((a, b) => b.value - a.value);\n\n const outOffset = b * k;\n const topKVals = allTopKVals.subarray(outOffset, outOffset + k);\n const topKIndices = allTopKIndices.subarray(outOffset, outOffset + k);\n for (let i = 0; i < k; i++) {\n topKVals[i] = valAndInd[i].value;\n topKIndices[i] = valAndInd[i].index;\n }\n }\n // Reshape back to the original input shape, except that the last\n // dimension is k.\n const outputShape = xShape.slice();\n outputShape[outputShape.length - 1] = k;\n return [\n tensor(allTopKVals, outputShape, xDtype) as T,\n tensor(allTopKIndices, outputShape, 'int32') as T\n ];\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/** An implementation of the Where kernel shared between cpu and webgl */\n\nimport {buffer} from '../ops/array_ops';\nimport {Tensor2D} from '../tensor';\nimport {TypedArray} from '../types';\n\nexport function whereImpl(condShape: number[], condVals: TypedArray): Tensor2D {\n const indices = [];\n for (let i = 0; i < condVals.length; i++) {\n if (condVals[i]) {\n indices.push(i);\n }\n }\n\n const inBuffer = buffer(condShape, 'int32');\n\n const out = buffer([indices.length, condShape.length], 'int32');\n for (let i = 0; i < indices.length; i++) {\n const loc = inBuffer.indexToLoc(indices[i]);\n const offset = i * condShape.length;\n out.values.set(loc, offset);\n }\n return out.toTensor() as Tensor2D;\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class AddNProgram implements GPGPUProgram {\n variableNames: string[];\n outputShape: number[] = [];\n userCode: string;\n\n constructor(outputShape: number[], shapes: number[][]) {\n this.outputShape = outputShape;\n this.variableNames = shapes.map((_, i) => `T${i}`);\n\n const snippets: string[] = [];\n // Get target elements from every input tensor.\n this.variableNames.forEach(variable => {\n snippets.push(`float v${variable} = get${variable}AtOutCoords();`);\n });\n\n // Calculate the sum of all elements.\n const operation = this.variableNames\n .map(variable => {\n return `v${variable}`;\n })\n .join(' + ');\n\n this.userCode = `\n void main() {\n ${snippets.join('\\n ')}\n\n float result = ${operation};\n setOutput(result);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class AddNPackedProgram implements GPGPUProgram {\n variableNames: string[];\n outputShape: number[] = [];\n userCode: string;\n packedInputs = true;\n packedOutput = true;\n\n constructor(outputShape: number[], shapes: number[][]) {\n this.outputShape = outputShape;\n this.variableNames = shapes.map((_, i) => `T${i}`);\n\n const snippets: string[] = [];\n // Get target elements from every input tensor.\n this.variableNames.forEach(variable => {\n snippets.push(`vec4 v${variable} = get${variable}AtOutCoords();`);\n });\n\n // Calculate the sum of all elements.\n const operation = this.variableNames\n .map(variable => {\n return `v${variable}`;\n })\n .join(' + ');\n\n this.userCode = `\n void main() {\n ${snippets.join('\\n ')}\n\n vec4 result = ${operation};\n setOutput(result);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ReduceInfo} from '../../ops/reduce_util';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class ArgMinMaxProgram implements GPGPUProgram {\n variableNames = ['A'];\n outputShape: number[];\n userCode: string;\n\n constructor(reduceInfo: ReduceInfo, op: 'max'|'min', firstPass: boolean) {\n const windowSize = reduceInfo.windowSize;\n const batchSize = reduceInfo.batchSize;\n const inSize = reduceInfo.inSize;\n const outSize = Math.ceil(inSize / windowSize);\n if (!firstPass) {\n this.variableNames.push('bestIndicesA');\n }\n this.outputShape = [batchSize, outSize];\n const compOp = (op === 'max') ? '>' : '<';\n const indexSnippet = firstPass ?\n 'inOffset + i;' :\n 'round(getBestIndicesA(batch, inOffset + i));';\n\n this.userCode = `\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * ${windowSize};\n\n int bestIndex = inOffset;\n float bestValue = getA(batch, bestIndex);\n\n for (int i = 0; i < ${windowSize}; i++) {\n int inIdx = ${indexSnippet};\n float candidate = getA(batch, inIdx);\n if (candidate ${compOp} bestValue) {\n bestValue = candidate;\n bestIndex = inIdx;\n }\n }\n setOutput(float(bestIndex));\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nexport function getVecChannels(name: string, rank: number): string[] {\n return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(d => `${name}.${d}`);\n}\n\nexport function getChannels(name: string, rank: number): string[] {\n if (rank === 1) {\n return [name];\n }\n return getVecChannels(name, rank);\n}\n\nexport function getSourceCoords(rank: number, dims: string[]): string {\n if (rank === 1) {\n return 'rc';\n }\n\n let coords = '';\n for (let i = 0; i < rank; i++) {\n coords += dims[i];\n if (i < rank - 1) {\n coords += ',';\n }\n }\n return coords;\n}","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {env} from '../../environment';\n\nexport type GLSL = {\n version: string,\n attribute: string,\n varyingVs: string,\n varyingFs: string,\n texture2D: string,\n output: string,\n defineOutput: string,\n defineSpecialNaN: string,\n defineSpecialInf: string,\n defineRound: string\n};\n\nexport function getGlslDifferences(): GLSL {\n let version: string;\n let attribute: string;\n let varyingVs: string;\n let varyingFs: string;\n let texture2D: string;\n let output: string;\n let defineOutput: string;\n let defineSpecialNaN: string;\n let defineSpecialInf: string;\n let defineRound: string;\n\n if (env().getNumber('WEBGL_VERSION') === 2) {\n version = '#version 300 es';\n attribute = 'in';\n varyingVs = 'out';\n varyingFs = 'in';\n texture2D = 'texture';\n output = 'outputColor';\n defineOutput = 'out vec4 outputColor;';\n\n // Use custom isnan definition to work across differences between\n // implementations on various platforms. While this should happen in ANGLE\n // we still see differences between android and windows (on chrome) when\n // using isnan directly.\n defineSpecialNaN = `\n bool isnan_custom(float val) {\n return (val > 0.0 || val < 0.0) ? false : val != 0.0;\n }\n\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan_custom(val.x),\n isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));\n }\n\n #define isnan(value) isnan_custom(value)\n `;\n // In webgl 2 we do not need to specify a custom isinf so there is no\n // need for a special INFINITY constant.\n defineSpecialInf = ``;\n defineRound = `\n #define round(value) newRound(value)\n int newRound(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 newRound(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n `;\n } else {\n version = '';\n attribute = 'attribute';\n varyingVs = 'varying';\n varyingFs = 'varying';\n texture2D = 'texture2D';\n output = 'gl_FragColor';\n defineOutput = '';\n // WebGL1 has no built in isnan so we define one here.\n defineSpecialNaN = `\n #define isnan(value) isnan_custom(value)\n bool isnan_custom(float val) {\n return (val > 0. || val < 1. || val == 0.) ? false : true;\n }\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));\n }\n `;\n defineSpecialInf = `\n uniform float INFINITY;\n\n bool isinf(float val) {\n return abs(val) == INFINITY;\n }\n bvec4 isinf(vec4 val) {\n return equal(abs(val), vec4(INFINITY));\n }\n `;\n defineRound = `\n int round(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 round(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n `;\n }\n\n return {\n version,\n attribute,\n varyingVs,\n varyingFs,\n texture2D,\n output,\n defineOutput,\n defineSpecialNaN,\n defineSpecialInf,\n defineRound\n };\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as util from '../../util';\n\n/**\n * Produces GLSL code that derives logical coordinates from a flat\n * index. The code performs integer division with each stride and decrements\n * the index until the index equals the final dimension coordinate.\n */\nexport function getLogicalCoordinatesFromFlatIndex(\n coords: string[], shape: number[], index = 'index'): string {\n const strides = util.computeStrides(shape);\n return strides\n .map((stride, i) => {\n const line1 = `int ${coords[i]} = ${index} / ${stride}`;\n const line2 = i === strides.length - 1 ?\n `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${stride}` :\n `index -= ${coords[i]} * ${stride}`;\n return `${line1}; ${line2};`;\n })\n .join('');\n}\n\nfunction buildVec(x: string[]): string {\n if (x.length === 1) {\n return `${x[0]}`;\n }\n return `vec${x.length}(${x.join(',')})`;\n}\n\n/**\n * Produces GLSL code that computes the dot product of the input x and y\n * vectors. Handles splitting inputs into increments of vec4s when necessary.\n */\nexport function dotify(x: string[], y: string[]): string {\n if (x.length !== y.length) {\n throw new Error(\n `Vectors to be dotted must be of the same length -` +\n `got ${x.length} and ${y.length}`);\n }\n\n const slices: string[] = [];\n const nearestVec4 = Math.floor(x.length / 4);\n const nearestVec4Remainder = x.length % 4;\n\n for (let i = 0; i < nearestVec4; i++) {\n const xSlice = x.slice(i * 4, i * 4 + 4);\n const ySlice = y.slice(i * 4, i * 4 + 4);\n slices.push(`${buildVec(xSlice)}, ${buildVec(ySlice)}`);\n }\n\n if (nearestVec4Remainder !== 0) {\n let xSlice = x.slice(nearestVec4 * 4);\n let ySlice = y.slice(nearestVec4 * 4);\n if (xSlice.length === 1) {\n xSlice = xSlice.map(d => `float(${d})`);\n ySlice = ySlice.map(d => `float(${d})`);\n }\n slices.push(`${buildVec(xSlice)}, ${buildVec(ySlice)}`);\n }\n\n return slices.map((d, i) => `dot(${d})`).join('+');\n}\n\n/**\n * Produces GLSL that computes the flat index from 3D coordinates.\n */\nexport function getFlatIndexFrom3D(shape: [number, number, number]): string {\n const strides = util.computeStrides(shape).map(d => d.toString());\n\n return `\n int getFlatIndex(ivec3 coords) {\n return coords.x * ${strides[0]} + coords.y * ${strides[1]} + coords.z;\n }\n`;\n}\n\nexport const ENCODE_FLOAT_SNIPPET = `\n const float FLOAT_MAX = 1.70141184e38;\n const float FLOAT_MIN = 1.17549435e-38;\n\n lowp vec4 encode_float(highp float v) {\n if (isnan(v)) {\n return vec4(255, 255, 255, 255);\n }\n\n highp float av = abs(v);\n\n if(av < FLOAT_MIN) {\n return vec4(0.0, 0.0, 0.0, 0.0);\n } else if(v > FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;\n } else if(v < -FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;\n }\n\n highp vec4 c = vec4(0,0,0,0);\n\n highp float e = floor(log2(av));\n highp float m = exp2(fract(log2(av))) - 1.0;\n\n c[2] = floor(128.0 * m);\n m -= c[2] / 128.0;\n c[1] = floor(32768.0 * m);\n m -= c[1] / 32768.0;\n c[0] = floor(8388608.0 * m);\n\n highp float ebias = e + 127.0;\n c[3] = floor(ebias / 2.0);\n ebias -= c[3] * 2.0;\n c[2] += floor(ebias) * 128.0;\n\n c[3] += 128.0 * step(0.0, -v);\n\n return c / 255.0;\n }\n`;","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getBroadcastDims} from '../../ops/broadcast_util';\nimport * as util from '../../util';\nimport {getGlslDifferences, GLSL} from './glsl_version';\nimport * as shader_util from './shader_compiler_util';\n\nexport type ShapeInfo = {\n logicalShape: number[],\n texShape: [number, number],\n isUniform: boolean,\n isPacked: boolean,\n flatOffset: number\n};\n\nexport type InputInfo = {\n name: string,\n shapeInfo: ShapeInfo\n};\n\nexport function makeShader(\n inputsInfo: InputInfo[], outputShape: ShapeInfo, userCode: string,\n usesPackedTextures: boolean): string {\n const prefixSnippets: string[] = [];\n inputsInfo.forEach(x => {\n const size = util.sizeFromShape(x.shapeInfo.logicalShape);\n\n // Snippet when we decided to upload the values as uniform.\n if (x.shapeInfo.isUniform) {\n prefixSnippets.push(\n `uniform float ${x.name}${size > 1 ? `[${size}]` : ''};`);\n } else {\n prefixSnippets.push(`uniform sampler2D ${x.name};`);\n prefixSnippets.push(`uniform int offset${x.name};`);\n }\n });\n const inputPrefixSnippet = prefixSnippets.join('\\n');\n\n const inputSamplingSnippet =\n inputsInfo\n .map(x => getInputSamplingSnippet(x, outputShape, usesPackedTextures))\n .join('\\n');\n const outTexShape = outputShape.texShape;\n const glsl = getGlslDifferences();\n const floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl);\n let outputSamplingSnippet: string;\n let floatTextureSetOutputSnippet: string;\n let shaderPrefix = getShaderPrefix(glsl);\n\n if (outputShape.isPacked) {\n outputSamplingSnippet =\n getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape);\n floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl);\n } else {\n outputSamplingSnippet =\n getOutputSamplingSnippet(outputShape.logicalShape, outTexShape);\n floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl);\n }\n\n if (usesPackedTextures) {\n shaderPrefix += SHADER_PACKED_PREFIX;\n }\n\n const source = [\n shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet,\n inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet, userCode\n ].join('\\n');\n return source;\n}\n\nfunction getSamplerFromInInfo(inInfo: InputInfo): string {\n const shape = inInfo.shapeInfo.logicalShape;\n switch (shape.length) {\n case 0:\n return getSamplerScalar(inInfo);\n case 1:\n return getSampler1D(inInfo);\n case 2:\n return getSampler2D(inInfo);\n case 3:\n return getSampler3D(inInfo);\n case 4:\n return getSampler4D(inInfo);\n case 5:\n return getSampler5D(inInfo);\n case 6:\n return getSampler6D(inInfo);\n default:\n throw new Error(\n `${shape.length}-D input sampling` +\n ` is not yet supported`);\n }\n}\n\nfunction getPackedSamplerFromInInfo(inInfo: InputInfo): string {\n const shape = inInfo.shapeInfo.logicalShape;\n switch (shape.length) {\n case 0:\n return getPackedSamplerScalar(inInfo);\n case 1:\n return getPackedSampler1D(inInfo);\n case 2:\n return getPackedSampler2D(inInfo);\n case 3:\n return getPackedSampler3D(inInfo);\n default:\n return getPackedSamplerND(inInfo);\n }\n}\n\nfunction getInputSamplingSnippet(\n inInfo: InputInfo, outShapeInfo: ShapeInfo,\n usesPackedTextures = false): string {\n let res = '';\n if (usesPackedTextures) {\n res += getPackedSamplerFromInInfo(inInfo);\n } else {\n res += getSamplerFromInInfo(inInfo);\n }\n\n const inShape = inInfo.shapeInfo.logicalShape;\n const outShape = outShapeInfo.logicalShape;\n if (inShape.length <= outShape.length) {\n if (usesPackedTextures) {\n res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo);\n } else {\n res += getSamplerAtOutputCoords(inInfo, outShapeInfo);\n }\n }\n return res;\n}\n\nfunction getPackedOutputSamplingSnippet(\n outShape: number[], outTexShape: [number, number]): string {\n switch (outShape.length) {\n case 0:\n return getOutputScalarCoords();\n case 1:\n return getOutputPacked1DCoords(outShape as [number], outTexShape);\n case 2:\n return getOutputPacked2DCoords(outShape as [number, number], outTexShape);\n case 3:\n return getOutputPacked3DCoords(\n outShape as [number, number, number], outTexShape);\n default:\n return getOutputPackedNDCoords(outShape, outTexShape);\n }\n}\n\nfunction getOutputSamplingSnippet(\n outShape: number[], outTexShape: [number, number]): string {\n switch (outShape.length) {\n case 0:\n return getOutputScalarCoords();\n case 1:\n return getOutput1DCoords(outShape as [number], outTexShape);\n case 2:\n return getOutput2DCoords(outShape as [number, number], outTexShape);\n case 3:\n return getOutput3DCoords(\n outShape as [number, number, number], outTexShape);\n case 4:\n return getOutput4DCoords(\n outShape as [number, number, number, number], outTexShape);\n case 5:\n return getOutput5DCoords(\n outShape as [number, number, number, number, number], outTexShape);\n case 6:\n return getOutput6DCoords(\n outShape as [number, number, number, number, number, number],\n outTexShape);\n default:\n throw new Error(\n `${outShape.length}-D output sampling is not yet supported`);\n }\n}\n\nfunction getFloatTextureSampleSnippet(glsl: GLSL): string {\n return `\n float sampleTexture(sampler2D textureSampler, vec2 uv) {\n return ${glsl.texture2D}(textureSampler, uv).r;\n }\n `;\n}\n\nfunction getFloatTextureSetRSnippet(glsl: GLSL): string {\n return `\n void setOutput(float val) {\n ${glsl.output} = vec4(val, 0, 0, 0);\n }\n `;\n}\n\nfunction getFloatTextureSetRGBASnippet(glsl: GLSL): string {\n return `\n void setOutput(vec4 val) {\n ${glsl.output} = val;\n }\n `;\n}\n\nfunction getShaderPrefix(glsl: GLSL): string {\n const SHADER_PREFIX = `${glsl.version}\n precision highp float;\n precision highp int;\n precision highp sampler2D;\n ${glsl.varyingFs} vec2 resultUV;\n ${glsl.defineOutput}\n const vec2 halfCR = vec2(0.5, 0.5);\n\n struct ivec5\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n };\n\n struct ivec6\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n int v;\n };\n\n uniform float NAN;\n ${glsl.defineSpecialNaN}\n ${glsl.defineSpecialInf}\n ${glsl.defineRound}\n\n int imod(int x, int y) {\n return x - y * (x / y);\n }\n\n int idiv(int a, int b, float sign) {\n int res = a / b;\n int mod = imod(a, b);\n if (sign < 0. && mod != 0) {\n res -= 1;\n }\n return res;\n }\n\n //Based on the work of Dave Hoskins\n //https://www.shadertoy.com/view/4djSRW\n #define HASHSCALE1 443.8975\n float random(float seed){\n vec2 p = resultUV * seed;\n vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);\n p3 += dot(p3, p3.yzx + 19.19);\n return fract((p3.x + p3.y) * p3.z);\n }\n\n ${SAMPLE_1D_SNIPPET}\n ${SAMPLE_2D_SNIPPET}\n ${SAMPLE_3D_SNIPPET}\n `;\n\n return SHADER_PREFIX;\n}\n\nconst SAMPLE_1D_SNIPPET = `\nvec2 uvFromFlat(int texNumR, int texNumC, int index) {\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\nvec2 packedUVfrom1D(int texNumR, int texNumC, int index) {\n int texelIndex = index / 2;\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n`;\n\nconst SAMPLE_2D_SNIPPET = `\nvec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,\n int texNumC, int row, int col) {\n int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n`;\n\nconst SAMPLE_3D_SNIPPET = `\nvec2 packedUVfrom3D(int texNumR, int texNumC,\n int texelsInBatch, int texelsInLogicalRow, int b,\n int row, int col) {\n int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n`;\n\nconst SHADER_PACKED_PREFIX = `\n float getChannel(vec4 frag, vec2 innerDims) {\n vec2 modCoord = mod(innerDims, 2.);\n return modCoord.x == 0. ?\n (modCoord.y == 0. ? frag.r : frag.g) :\n (modCoord.y == 0. ? frag.b : frag.a);\n }\n float getChannel(vec4 frag, int dim) {\n float modCoord = mod(float(dim), 2.);\n return modCoord == 0. ? frag.r : frag.g;\n }\n`;\n\nfunction getOutputScalarCoords() {\n return `\n int getOutputCoords() {\n return 0;\n }\n `;\n}\n\nfunction getOutputPacked1DCoords(\n shape: [number], texShape: [number, number]): string {\n const packedTexShape =\n [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];\n if (packedTexShape[0] === 1) {\n return `\n int getOutputCoords() {\n return 2 * int(resultUV.x * ${packedTexShape[1]}.0);\n }\n `;\n }\n\n if (packedTexShape[1] === 1) {\n return `\n int getOutputCoords() {\n return 2 * int(resultUV.y * ${packedTexShape[0]}.0);\n }\n `;\n }\n\n return `\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${packedTexShape[0]}, ${packedTexShape[1]}));\n return 2 * (resTexRC.x * ${packedTexShape[1]} + resTexRC.y);\n }\n `;\n}\n\nfunction getOutput1DCoords(\n shape: [number], texShape: [number, number]): string {\n if (texShape[0] === 1) {\n return `\n int getOutputCoords() {\n return int(resultUV.x * ${texShape[1]}.0);\n }\n `;\n }\n if (texShape[1] === 1) {\n return `\n int getOutputCoords() {\n return int(resultUV.y * ${texShape[0]}.0);\n }\n `;\n }\n return `\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${texShape[0]}, ${texShape[1]}));\n return resTexRC.x * ${texShape[1]} + resTexRC.y;\n }\n `;\n}\n\nfunction getOutputPacked3DCoords(\n shape: [number, number, number], texShape: [number, number]): string {\n const packedTexShape =\n [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];\n const texelsInLogicalRow = Math.ceil(shape[2] / 2);\n const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2);\n\n return `\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${packedTexShape[0]}, ${packedTexShape[1]}));\n int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;\n\n int b = index / ${texelsInBatch};\n index -= b * ${texelsInBatch};\n\n int r = 2 * (index / ${texelsInLogicalRow});\n int c = imod(index, ${texelsInLogicalRow}) * 2;\n\n return ivec3(b, r, c);\n }\n `;\n}\n\nfunction getOutput3DCoords(\n shape: [number, number, number], texShape: [number, number]): string {\n const coordsFromIndexSnippet =\n shader_util.getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);\n\n return `\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${texShape[0]}, ${texShape[1]}));\n int index = resTexRC.x * ${texShape[1]} + resTexRC.y;\n ${coordsFromIndexSnippet}\n return ivec3(r, c, d);\n }\n `;\n}\n\nfunction getOutputPackedNDCoords(\n shape: number[], texShape: [number, number]): string {\n const packedTexShape =\n [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];\n\n const texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2);\n const texelsInBatch =\n texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2);\n let texelsInBatchN = texelsInBatch;\n let batches = ``;\n let coords = 'b, r, c';\n\n for (let b = 2; b < shape.length - 1; b++) {\n texelsInBatchN *= shape[shape.length - b - 1];\n batches = `\n int b${b} = index / ${texelsInBatchN};\n index -= b${b} * ${texelsInBatchN};\n ` + batches;\n coords = `b${b}, ` + coords;\n }\n\n return `\n ivec${shape.length} getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${packedTexShape[0]}, ${packedTexShape[1]}));\n int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;\n\n ${batches}\n\n int b = index / ${texelsInBatch};\n index -= b * ${texelsInBatch};\n\n int r = 2 * (index / ${texelsInLogicalRow});\n int c = imod(index, ${texelsInLogicalRow}) * 2;\n\n return ivec${shape.length}(${coords});\n }\n `;\n}\n\nfunction getOutput4DCoords(\n shape: [number, number, number, number],\n texShape: [number, number]): string {\n const coordsFromIndexSnippet = shader_util.getLogicalCoordinatesFromFlatIndex(\n ['r', 'c', 'd', 'd2'], shape);\n\n return `\n ivec4 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${texShape[0]}, ${texShape[1]}));\n int index = resTexRC.x * ${texShape[1]} + resTexRC.y;\n ${coordsFromIndexSnippet}\n return ivec4(r, c, d, d2);\n }\n `;\n}\n\nfunction getOutput5DCoords(\n shape: [number, number, number, number, number],\n texShape: [number, number]): string {\n const coordsFromIndexSnippet = shader_util.getLogicalCoordinatesFromFlatIndex(\n ['r', 'c', 'd', 'd2', 'd3'], shape);\n\n return `\n ivec5 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]},\n ${texShape[1]}));\n\n int index = resTexRC.x * ${texShape[1]} + resTexRC.y;\n\n ${coordsFromIndexSnippet}\n\n ivec5 outShape = ivec5(r, c, d, d2, d3);\n return outShape;\n }\n `;\n}\n\nfunction getOutput6DCoords(\n shape: [number, number, number, number, number, number],\n texShape: [number, number]): string {\n const coordsFromIndexSnippet = shader_util.getLogicalCoordinatesFromFlatIndex(\n ['r', 'c', 'd', 'd2', 'd3', 'd4'], shape);\n\n return `\n ivec6 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${texShape[0]}, ${texShape[1]}));\n int index = resTexRC.x * ${texShape[1]} + resTexRC.y;\n\n ${coordsFromIndexSnippet}\n\n ivec6 result = ivec6(r, c, d, d2, d3, d4);\n return result;\n }\n `;\n}\n\nfunction getOutputPacked2DCoords(\n shape: [number, number], texShape: [number, number]): string {\n const packedTexShape =\n [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];\n if (util.arraysEqual(shape, texShape)) {\n return `\n ivec2 getOutputCoords() {\n return 2 * ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${\n packedTexShape[1]}));\n }\n `;\n }\n\n // texels needed to accommodate a logical row\n const texelsInLogicalRow = Math.ceil(shape[1] / 2);\n\n /**\n * getOutputCoords\n *\n * resTexRC: The rows and columns of the texels. If you move over one\n * texel to the right in the packed texture, you are moving over one column\n * (not two).\n *\n * index: The texel index\n */\n return `\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${packedTexShape[0]}, ${packedTexShape[1]}));\n\n int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;\n int r = 2 * (index / ${texelsInLogicalRow});\n int c = imod(index, ${texelsInLogicalRow}) * 2;\n\n return ivec2(r, c);\n }\n `;\n}\n\nfunction getOutput2DCoords(\n shape: [number, number], texShape: [number, number]): string {\n if (util.arraysEqual(shape, texShape)) {\n return `\n ivec2 getOutputCoords() {\n return ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]}));\n }\n `;\n }\n if (shape[1] === 1) {\n return `\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${texShape[0]}, ${texShape[1]}));\n int index = resTexRC.x * ${texShape[1]} + resTexRC.y;\n return ivec2(index, 0);\n }\n `;\n }\n if (shape[0] === 1) {\n return `\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${texShape[0]}, ${texShape[1]}));\n int index = resTexRC.x * ${texShape[1]} + resTexRC.y;\n return ivec2(0, index);\n }\n `;\n }\n return `\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${texShape[0]}, ${texShape[1]}));\n int index = resTexRC.x * ${texShape[1]} + resTexRC.y;\n int r = index / ${shape[1]};\n int c = index - r * ${shape[1]};\n return ivec2(r, c);\n }\n `;\n}\n\nfunction getFlatOffsetUniformName(texName: string): string {\n return `offset${texName}`;\n}\n\nfunction getPackedSamplerScalar(inputInfo: InputInfo): string {\n const texName = inputInfo.name;\n const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);\n const glsl = getGlslDifferences();\n return `\n vec4 ${funcName}() {\n return ${glsl.texture2D}(${texName}, halfCR);\n }\n `;\n}\n\nfunction getSamplerScalar(inputInfo: InputInfo): string {\n const texName = inputInfo.name;\n const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);\n if (inputInfo.shapeInfo.isUniform) {\n return `float ${funcName}() {return ${texName};}`;\n }\n const [texNumR, texNumC] = inputInfo.shapeInfo.texShape;\n if (texNumR === 1 && texNumC === 1) {\n return `\n float ${funcName}() {\n return sampleTexture(${texName}, halfCR);\n }\n `;\n }\n\n const [tNumR, tNumC] = inputInfo.shapeInfo.texShape;\n const offset = getFlatOffsetUniformName(texName);\n return `\n float ${funcName}() {\n vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, ${offset});\n return sampleTexture(${texName}, uv);\n }\n `;\n}\n\nfunction getPackedSampler1D(inputInfo: InputInfo): string {\n const texName = inputInfo.name;\n const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);\n const texShape = inputInfo.shapeInfo.texShape;\n const packedTexShape =\n [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];\n const glsl = getGlslDifferences();\n\n return `\n vec4 ${funcName}(int index) {\n vec2 uv = packedUVfrom1D(\n ${packedTexShape[0]}, ${packedTexShape[1]}, index);\n return ${glsl.texture2D}(${texName}, uv);\n }\n `;\n}\n\nfunction getSampler1D(inputInfo: InputInfo): string {\n const texName = inputInfo.name;\n const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);\n\n if (inputInfo.shapeInfo.isUniform) {\n // Uniform arrays will be less than 65505 (no risk of float16 overflow).\n return `\n float ${funcName}(int index) {\n ${getUniformSampler(inputInfo)}\n }\n `;\n }\n\n const texShape = inputInfo.shapeInfo.texShape;\n const tNumR = texShape[0];\n const tNumC = texShape[1];\n\n if (tNumC === 1 && tNumR === 1) {\n return `\n float ${funcName}(int index) {\n return sampleTexture(${texName}, halfCR);\n }\n `;\n }\n const offset = getFlatOffsetUniformName(texName);\n if (tNumC === 1) {\n return `\n float ${funcName}(int index) {\n vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / ${tNumR}.0);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n if (tNumR === 1) {\n return `\n float ${funcName}(int index) {\n vec2 uv = vec2((float(index + ${offset}) + 0.5) / ${tNumC}.0, 0.5);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n return `\n float ${funcName}(int index) {\n vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, index + ${offset});\n return sampleTexture(${texName}, uv);\n }\n `;\n}\n\nfunction getPackedSampler2D(inputInfo: InputInfo): string {\n const shape = inputInfo.shapeInfo.logicalShape;\n const texName = inputInfo.name;\n const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);\n const texShape = inputInfo.shapeInfo.texShape;\n\n const texNumR = texShape[0];\n const texNumC = texShape[1];\n const glsl = getGlslDifferences();\n if (texShape != null && util.arraysEqual(shape, texShape)) {\n return `\n vec4 ${funcName}(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);\n\n return ${glsl.texture2D}(${texName}, uv);\n }\n `;\n }\n\n const packedTexShape =\n [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];\n const valuesPerRow = Math.ceil(shape[1] / 2);\n\n return `\n vec4 ${funcName}(int row, int col) {\n vec2 uv = packedUVfrom2D(${valuesPerRow}, ${packedTexShape[0]}, ${\n packedTexShape[1]}, row, col);\n return ${glsl.texture2D}(${texName}, uv);\n }\n `;\n}\n\nfunction getSampler2D(inputInfo: InputInfo): string {\n const shape = inputInfo.shapeInfo.logicalShape;\n const texName = inputInfo.name;\n const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);\n const texShape = inputInfo.shapeInfo.texShape;\n\n if (texShape != null && util.arraysEqual(shape, texShape)) {\n const texNumR = texShape[0];\n const texNumC = texShape[1];\n return `\n float ${funcName}(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n\n const {newShape, keptDims} = util.squeezeShape(shape);\n const squeezedShape = newShape;\n if (squeezedShape.length < shape.length) {\n const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);\n const params = ['row', 'col'];\n return `\n ${getSamplerFromInInfo(newInputInfo)}\n float ${funcName}(int row, int col) {\n return ${funcName}(${getSqueezedParams(params, keptDims)});\n }\n `;\n }\n\n if (inputInfo.shapeInfo.isUniform) {\n // Uniform arrays will be less than 65505 (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col) {\n int index = round(dot(vec2(row, col), vec2(${shape[1]}, 1)));\n ${getUniformSampler(inputInfo)}\n }\n `;\n }\n\n const texNumR = texShape[0];\n const texNumC = texShape[1];\n const offset = getFlatOffsetUniformName(texName);\n if (texNumC === 1) {\n // index is used directly as physical (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col) {\n float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));\n vec2 uv = vec2(0.5, (index + 0.5) / ${texNumR}.0);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n if (texNumR === 1) {\n // index is used directly as physical (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col) {\n float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));\n vec2 uv = vec2((index + 0.5) / ${texNumC}.0, 0.5);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n\n return `\n float ${funcName}(int row, int col) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ${shape[1]} + col + ${offset};\n vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);\n return sampleTexture(${texName}, uv);\n }\n`;\n}\n\nfunction getPackedSampler3D(inputInfo: InputInfo): string {\n const shape = inputInfo.shapeInfo.logicalShape;\n const texName = inputInfo.name;\n const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);\n const texShape = inputInfo.shapeInfo.texShape;\n const packedTexShape =\n [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];\n\n if (shape[0] === 1) {\n const squeezedShape = shape.slice(1);\n const keptDims = [1, 2];\n const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);\n const params = ['b', 'row', 'col'];\n return `\n ${getPackedSamplerFromInInfo(newInputInfo)}\n vec4 ${funcName}(int b, int row, int col) {\n return ${funcName}(${getSqueezedParams(params, keptDims)});\n }\n `;\n }\n\n const texNumR = packedTexShape[0];\n const texNumC = packedTexShape[1];\n\n const valuesPerRow = Math.ceil(shape[2] / 2);\n const texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2);\n const glsl = getGlslDifferences();\n\n return `\n vec4 ${funcName}(int b, int row, int col) {\n vec2 uv = packedUVfrom3D(\n ${texNumR}, ${texNumC}, ${texelsInBatch}, ${valuesPerRow}, b, row, col);\n return ${glsl.texture2D}(${texName}, uv);\n }\n `;\n}\n\nfunction getSampler3D(inputInfo: InputInfo): string {\n const shape = inputInfo.shapeInfo.logicalShape;\n const texName = inputInfo.name;\n const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);\n const stride0 = shape[1] * shape[2];\n const stride1 = shape[2];\n\n const {newShape, keptDims} = util.squeezeShape(shape);\n const squeezedShape = newShape;\n if (squeezedShape.length < shape.length) {\n const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);\n const params = ['row', 'col', 'depth'];\n return `\n ${getSamplerFromInInfo(newInputInfo)}\n float ${funcName}(int row, int col, int depth) {\n return ${funcName}(${getSqueezedParams(params, keptDims)});\n }\n `;\n }\n\n if (inputInfo.shapeInfo.isUniform) {\n // Uniform arrays will be less than 65505 (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col, int depth) {\n int index = round(dot(vec3(row, col, depth),\n vec3(${stride0}, ${stride1}, 1)));\n ${getUniformSampler(inputInfo)}\n }\n `;\n }\n\n const texShape = inputInfo.shapeInfo.texShape;\n const texNumR = texShape[0];\n const texNumC = texShape[1];\n const flatOffset = inputInfo.shapeInfo.flatOffset;\n if (texNumC === stride0 && flatOffset == null) {\n // texC is used directly as physical (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col, int depth) {\n float texR = float(row);\n float texC = dot(vec2(col, depth), vec2(${stride1}, 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(${texNumC}.0, ${texNumR}.0);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n\n if (texNumC === stride1 && flatOffset == null) {\n // texR is used directly as physical (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col, int depth) {\n float texR = dot(vec2(row, col), vec2(${shape[1]}, 1));\n float texC = float(depth);\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n\n const offset = getFlatOffsetUniformName(texName);\n return `\n float ${funcName}(int row, int col, int depth) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ${stride0} + col * ${stride1} + depth + ${offset};\n vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);\n return sampleTexture(${texName}, uv);\n }\n `;\n}\n\nfunction getPackedSamplerND(inputInfo: InputInfo): string {\n const shape = inputInfo.shapeInfo.logicalShape;\n const rank = shape.length;\n const texName = inputInfo.name;\n const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);\n const texShape = inputInfo.shapeInfo.texShape;\n const packedTexShape =\n [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];\n const texNumR = packedTexShape[0];\n const texNumC = packedTexShape[1];\n\n const valuesPerRow = Math.ceil(shape[rank - 1] / 2);\n let texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2);\n let params = `int b, int row, int col`;\n let index = `b * ${texelsInBatch} + (row / 2) * ${valuesPerRow} + (col / 2)`;\n for (let b = 2; b < rank - 1; b++) {\n params = `int b${b}, ` + params;\n texelsInBatch *= shape[rank - b - 1];\n index = `b${b} * ${texelsInBatch} + ` + index;\n }\n const glsl = getGlslDifferences();\n return `\n vec4 ${funcName}(${params}) {\n int index = ${index};\n int texR = index / ${texNumC};\n int texC = index - texR * ${texNumC};\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}, ${texNumR});\n return ${glsl.texture2D}(${texName}, uv);\n }\n `;\n}\n\nfunction getSampler4D(inputInfo: InputInfo): string {\n const shape = inputInfo.shapeInfo.logicalShape;\n const texName = inputInfo.name;\n const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);\n const stride2 = shape[3];\n const stride1 = shape[2] * stride2;\n const stride0 = shape[1] * stride1;\n\n const {newShape, keptDims} = util.squeezeShape(shape);\n if (newShape.length < shape.length) {\n const newInputInfo = squeezeInputInfo(inputInfo, newShape);\n const params = ['row', 'col', 'depth', 'depth2'];\n return `\n ${getSamplerFromInInfo(newInputInfo)}\n float ${funcName}(int row, int col, int depth, int depth2) {\n return ${funcName}(${getSqueezedParams(params, keptDims)});\n }\n `;\n }\n\n if (inputInfo.shapeInfo.isUniform) {\n // Uniform arrays will be less than 65505 (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col, int depth, int depth2) {\n int index = round(dot(vec4(row, col, depth, depth2),\n vec4(${stride0}, ${stride1}, ${stride2}, 1)));\n ${getUniformSampler(inputInfo)}\n }\n `;\n }\n\n const flatOffset = inputInfo.shapeInfo.flatOffset;\n const texShape = inputInfo.shapeInfo.texShape;\n const texNumR = texShape[0];\n const texNumC = texShape[1];\n\n if (texNumC === stride0 && flatOffset == null) {\n // texC is used directly as physical (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col, int depth, int depth2) {\n float texR = float(row);\n float texC =\n dot(vec3(col, depth, depth2),\n vec3(${stride1}, ${stride2}, 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(${texNumC}.0, ${texNumR}.0);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n if (texNumC === stride2 && flatOffset == null) {\n // texR is used directly as physical (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col, int depth, int depth2) {\n float texR = dot(vec3(row, col, depth),\n vec3(${shape[1] * shape[2]}, ${shape[2]}, 1));\n float texC = float(depth2);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(${texNumC}.0, ${texNumR}.0);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n\n const offset = getFlatOffsetUniformName(texName);\n return `\n float ${funcName}(int row, int col, int depth, int depth2) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ${stride0} + col * ${stride1} +\n depth * ${stride2} + depth2;\n vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index + ${offset});\n return sampleTexture(${texName}, uv);\n }\n `;\n}\n\nfunction getSampler5D(inputInfo: InputInfo): string {\n const shape = inputInfo.shapeInfo.logicalShape;\n const texName = inputInfo.name;\n const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);\n const stride3 = shape[4];\n const stride2 = shape[3] * stride3;\n const stride1 = shape[2] * stride2;\n const stride0 = shape[1] * stride1;\n\n const {newShape, keptDims} = util.squeezeShape(shape);\n if (newShape.length < shape.length) {\n const newInputInfo = squeezeInputInfo(inputInfo, newShape);\n const params = ['row', 'col', 'depth', 'depth2', 'depth3'];\n return `\n ${getSamplerFromInInfo(newInputInfo)}\n float ${funcName}(int row, int col, int depth, int depth2, int depth3) {\n return ${funcName}(${getSqueezedParams(params, keptDims)});\n }\n `;\n }\n\n if (inputInfo.shapeInfo.isUniform) {\n // Uniform arrays will be less than 65505 (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col, int depth, int depth2, int depth3) {\n float index = dot(\n vec4(row, col, depth, depth2),\n vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +\n depth3;\n ${getUniformSampler(inputInfo)}\n }\n `;\n }\n\n const flatOffset = inputInfo.shapeInfo.flatOffset;\n const texShape = inputInfo.shapeInfo.texShape;\n const texNumR = texShape[0];\n const texNumC = texShape[1];\n\n if (texNumC === stride0 && flatOffset == null) {\n // texC is used directly as physical (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col, int depth, int depth2, int depth3) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(${stride1}, ${stride2}, ${stride3}, 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(${texNumC}.0, ${texNumR}.0);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n\n if (texNumC === stride3 && flatOffset == null) {\n // texR is used directly as physical (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col, int depth, int depth2, int depth3) {\n float texR = dot(\n vec4(row, col, depth, depth2),\n vec4(${shape[1] * shape[2] * shape[3]},\n ${shape[2] * shape[3]}, ${shape[3]}, 1));\n int texC = depth3;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(${texNumC}.0, ${texNumR}.0);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n\n const offset = getFlatOffsetUniformName(texName);\n return `\n float ${funcName}(int row, int col, int depth, int depth2, int depth3) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +\n depth2 * ${stride3} + depth3 + ${offset};\n vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);\n return sampleTexture(${texName}, uv);\n }\n `;\n}\n\nfunction getSampler6D(inputInfo: InputInfo): string {\n const shape = inputInfo.shapeInfo.logicalShape;\n const texName = inputInfo.name;\n const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);\n\n const {newShape, keptDims} = util.squeezeShape(shape);\n if (newShape.length < shape.length) {\n const newInputInfo = squeezeInputInfo(inputInfo, newShape);\n const params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4'];\n return `\n ${getSamplerFromInInfo(newInputInfo)}\n float ${funcName}(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n return ${funcName}(${getSqueezedParams(params, keptDims)});\n }\n `;\n }\n\n const stride4 = shape[5];\n const stride3 = shape[4] * stride4;\n const stride2 = shape[3] * stride3;\n const stride1 = shape[2] * stride2;\n const stride0 = shape[1] * stride1;\n\n if (inputInfo.shapeInfo.isUniform) {\n // Uniform arrays will be less than 65505 (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int index = round(dot(\n vec4(row, col, depth, depth2),\n vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +\n dot(\n vec2(depth3, depth4),\n vec2(${stride4}, 1)));\n ${getUniformSampler(inputInfo)}\n }\n `;\n }\n\n const flatOffset = inputInfo.shapeInfo.flatOffset;\n const texShape = inputInfo.shapeInfo.texShape;\n const texNumR = texShape[0];\n const texNumC = texShape[1];\n if (texNumC === stride0 && flatOffset == null) {\n // texC is used directly as physical (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(${stride1}, ${stride2}, ${stride3}, ${stride4})) +\n float(depth4);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(${texNumC}.0, ${texNumR}.0);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n if (texNumC === stride4 && flatOffset == null) {\n // texR is used directly as physical (no risk of float16 overflow).\n return `\n float ${funcName}(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n float texR = dot(vec4(row, col, depth, depth2),\n vec4(${shape[1] * shape[2] * shape[3] * shape[4]},\n ${shape[2] * shape[3] * shape[4]},\n ${shape[3] * shape[4]},\n ${shape[4]})) + float(depth3);\n int texC = depth4;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(${texNumC}.0, ${texNumR}.0);\n return sampleTexture(${texName}, uv);\n }\n `;\n }\n const offset = getFlatOffsetUniformName(texName);\n return `\n float ${funcName}(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +\n depth2 * ${stride3} + depth3 * ${stride4} + depth4 + ${offset};\n vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);\n return sampleTexture(${texName}, uv);\n }\n `;\n}\n\nfunction getUniformSampler(inputInfo: InputInfo): string {\n const texName = inputInfo.name;\n const inSize = util.sizeFromShape(inputInfo.shapeInfo.logicalShape);\n\n if (inSize < 2) {\n return `return ${texName};`;\n }\n return `\n for (int i = 0; i < ${inSize}; i++) {\n if (i == index) {\n return ${texName}[i];\n }\n }\n `;\n}\n\nfunction getPackedSamplerAtOutputCoords(\n inputInfo: InputInfo, outShapeInfo: ShapeInfo) {\n const texName = inputInfo.name;\n const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);\n const funcName = 'get' + texFuncSnippet + 'AtOutCoords';\n const inRank = inputInfo.shapeInfo.logicalShape.length;\n const outRank = outShapeInfo.logicalShape.length;\n\n const broadcastDims = getBroadcastDims(\n inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);\n\n const type = getCoordsDataType(outRank);\n const rankDiff = outRank - inRank;\n let coordsSnippet: string;\n const fields = ['x', 'y', 'z', 'w', 'u', 'v'];\n\n if (inRank === 0) {\n coordsSnippet = '';\n } else if (outRank < 2 && broadcastDims.length >= 1) {\n coordsSnippet = 'coords = 0;';\n } else {\n coordsSnippet =\n broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`)\n .join('\\n');\n }\n let unpackedCoordsSnippet = '';\n if (outRank < 2 && inRank > 0) {\n unpackedCoordsSnippet = 'coords';\n } else {\n unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape\n .map((s, i) => `coords.${fields[i + rankDiff]}`)\n .join(', ');\n }\n\n let output = `return outputValue;`;\n const inSize = util.sizeFromShape(inputInfo.shapeInfo.logicalShape);\n const isInputScalar = inSize === 1;\n const outSize = util.sizeFromShape(outShapeInfo.logicalShape);\n const isOutputScalar = outSize === 1;\n\n if (inRank === 1 && !isInputScalar && !isOutputScalar) {\n output = `\n return vec4(outputValue.xy, outputValue.xy);\n `;\n } else if (isInputScalar && !isOutputScalar) {\n if (outRank === 1) {\n output = `\n return vec4(outputValue.x, outputValue.x, 0., 0.);\n `;\n } else {\n output = `\n return vec4(outputValue.x);\n `;\n }\n } else if (broadcastDims.length) {\n const rows = inRank - 2;\n const cols = inRank - 1;\n\n if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) {\n output = `return vec4(outputValue.x);`;\n } else if (broadcastDims.indexOf(rows) > -1) {\n output = `return vec4(outputValue.x, outputValue.y, ` +\n `outputValue.x, outputValue.y);`;\n } else if (broadcastDims.indexOf(cols) > -1) {\n output = `return vec4(outputValue.xx, outputValue.zz);`;\n }\n }\n\n return `\n vec4 ${funcName}() {\n ${type} coords = getOutputCoords();\n ${coordsSnippet}\n vec4 outputValue = get${texFuncSnippet}(${unpackedCoordsSnippet});\n ${output}\n }\n `;\n}\n\nfunction getSamplerAtOutputCoords(\n inputInfo: InputInfo, outShapeInfo: ShapeInfo) {\n const texName = inputInfo.name;\n const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);\n const funcName = 'get' + texFuncSnippet + 'AtOutCoords';\n const outTexShape = outShapeInfo.texShape;\n const inTexShape = inputInfo.shapeInfo.texShape;\n const inRank = inputInfo.shapeInfo.logicalShape.length;\n const outRank = outShapeInfo.logicalShape.length;\n\n if (!inputInfo.shapeInfo.isUniform && inRank === outRank &&\n inputInfo.shapeInfo.flatOffset == null &&\n util.arraysEqual(inTexShape, outTexShape)) {\n return `\n float ${funcName}() {\n return sampleTexture(${texName}, resultUV);\n }\n `;\n }\n\n const type = getCoordsDataType(outRank);\n const broadcastDims = getBroadcastDims(\n inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);\n const rankDiff = outRank - inRank;\n let coordsSnippet: string;\n const fields = ['x', 'y', 'z', 'w', 'u', 'v'];\n\n if (inRank === 0) {\n coordsSnippet = '';\n } else if (outRank < 2 && broadcastDims.length >= 1) {\n coordsSnippet = 'coords = 0;';\n } else {\n coordsSnippet =\n broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`)\n .join('\\n');\n }\n let unpackedCoordsSnippet = '';\n if (outRank < 2 && inRank > 0) {\n unpackedCoordsSnippet = 'coords';\n } else {\n unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape\n .map((s, i) => `coords.${fields[i + rankDiff]}`)\n .join(', ');\n }\n\n return `\n float ${funcName}() {\n ${type} coords = getOutputCoords();\n ${coordsSnippet}\n return get${texFuncSnippet}(${unpackedCoordsSnippet});\n }\n `;\n}\n\nexport function getCoordsDataType(rank: number): string {\n if (rank <= 1) {\n return 'int';\n } else if (rank === 2) {\n return 'ivec2';\n } else if (rank === 3) {\n return 'ivec3';\n } else if (rank === 4) {\n return 'ivec4';\n } else if (rank === 5) {\n return 'ivec5';\n } else if (rank === 6) {\n return 'ivec6';\n } else {\n throw Error(`GPU for rank ${rank} is not yet supported`);\n }\n}\n\n/** Returns a new input info (a copy) that has a squeezed logical shape. */\nfunction squeezeInputInfo(\n inInfo: InputInfo, squeezedShape: number[]): InputInfo {\n // Deep copy.\n const newInputInfo: InputInfo = JSON.parse(JSON.stringify(inInfo));\n newInputInfo.shapeInfo.logicalShape = squeezedShape;\n return newInputInfo;\n}\n\nfunction getSqueezedParams(params: string[], keptDims: number[]): string {\n return keptDims.map(d => params[d]).join(', ');\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {assert} from '../../util';\nimport {getChannels} from '../packing_util';\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class ArgMinMaxPackedProgram implements GPGPUProgram {\n variableNames = ['A'];\n outputShape: number[];\n userCode: string;\n packedInputs = true;\n packedOutput = true;\n\n constructor(\n shape: number[], windowSize: number, op: 'max'|'min',\n firstPass: boolean) {\n assert(\n shape.length > 2,\n () => `Packed arg${\n op.charAt(0).toUpperCase() +\n op.slice(1)} supports only inputs with rank above 2.`);\n const inSize = shape[shape.length - 1];\n const outSize = Math.ceil(inSize / windowSize);\n this.outputShape = shape.slice(0, -1);\n if (outSize > 1) {\n this.outputShape.push(outSize);\n }\n if (!firstPass) {\n this.variableNames.push('bestIndicesA');\n }\n const outShape = this.outputShape;\n const rank = outShape.length;\n const dtype = getCoordsDataType(rank);\n const coords = getChannels('coords', rank);\n\n let sourceLocSetup;\n let sourceRank;\n if (outSize === 1) {\n sourceRank = rank + 1;\n const sourceLocDType = getCoordsDataType(sourceRank);\n sourceLocSetup = `\n ${sourceLocDType} sourceLocR = ${sourceLocDType}(${coords.join()}, 0);\n ++${coords[rank - 1]};\n ${sourceLocDType} sourceLocG = ${sourceLocDType}(${coords.join()}, 0);\n ++${coords[rank - 2]};\n ${sourceLocDType} sourceLocA = ${sourceLocDType}(${coords.join()}, 0);\n --${coords[rank - 1]};\n ${sourceLocDType} sourceLocB = ${sourceLocDType}(${coords.join()}, 0);\n --${coords[rank - 2]};`;\n } else {\n sourceRank = rank;\n sourceLocSetup = `\n ${dtype} sourceLocR = coords;\n ++${coords[rank - 1]};\n ${dtype} sourceLocG = coords;\n ++${coords[rank - 2]};\n ${dtype} sourceLocA = coords;\n --${coords[rank - 1]};\n ${dtype} sourceLocB = coords;\n --${coords[rank - 2]};`;\n }\n const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);\n const inChannel = '.' + channels[sourceRank - 1]; // e.g. \".b\" for rank 3.\n const intChannels = channels.map(x => 'int ' + x);\n const srcRCoords =\n getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');\n const srcGCoords =\n getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');\n const srcBCoords =\n getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');\n const srcACoords =\n getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');\n\n const compOp = (op === 'max') ? 'greaterThan' : 'lessThan';\n const fetchCandidateIdx = firstPass ? '' : `\n inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords.join()}),\n getBestIndicesAChannel(${srcGCoords.join()}),\n getBestIndicesAChannel(${srcBCoords.join()}),\n getBestIndicesAChannel(${srcACoords.join()})));`;\n\n const fetchValue = `vec4(\n getAChannel(${srcRCoords.join()}),\n hasNextCol ? getAChannel(${srcGCoords.join()}) : 0.,\n hasNextRow ? getAChannel(${srcBCoords.join()}) : 0.,\n hasNextRow && hasNextCol ? getAChannel(${srcACoords.join()}) : 0.)`;\n\n const getBestIndicesAChannelSnippet = firstPass ? '' : `\n float getBestIndicesAChannel(${intChannels.join()}) {\n return getChannel(getBestIndicesA(${channels.join()}),\n vec2(${channels.slice(-2).join()}));\n }`;\n\n this.userCode = `\n float getAChannel(${intChannels.join()}) {\n return getChannel(getA(${channels.join()}),\n vec2(${channels.slice(-2).join()}));\n }\n ${getBestIndicesAChannelSnippet}\n void main() {\n ${dtype} coords = getOutputCoords();\n bool hasNextCol = ${coords[rank - 1]} < ${outShape[rank - 1] - 1};\n bool hasNextRow = ${coords[rank - 2]} < ${outShape[rank - 2] - 1};\n ${sourceLocSetup}\n ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},\n sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};\n ivec4 inIdx = srcIdx;\n vec4 bestIndex = vec4(inIdx);\n vec4 bestValue = ${fetchValue};\n\n for (int i = 0; i < ${windowSize}; i++) {\n inIdx = srcIdx;\n ${fetchCandidateIdx}\n vec4 candidate = ${fetchValue};\n bvec4 nan = isnan(candidate);\n bvec4 replace = bvec4(\n vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));\n\n bestValue = vec4(replace.x ? candidate.x : bestValue.x,\n replace.y ? candidate.y : bestValue.y,\n replace.z ? candidate.z : bestValue.z,\n replace.w ? candidate.w : bestValue.w);\n bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));\n srcIdx++;\n }\n setOutput(bestIndex);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class AvgPool2DBackpropProgram implements GPGPUProgram {\n variableNames = ['dy'];\n outputShape: number[];\n userCode: string;\n\n constructor(convInfo: Conv2DInfo) {\n this.outputShape = convInfo.inShape;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n\n const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;\n const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;\n\n const avgMultiplier = 1 / (filterHeight * filterWidth);\n\n this.userCode = `\n const ivec2 pads = ivec2(${padTop}, ${padLeft});\n const float avgMultiplier = float(${avgMultiplier});\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < ${effectiveFilterHeight};\n wR += ${dilationHeight}) {\n float dyR = float(dyRCorner + wR) / ${strideHeight}.0;\n\n if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < ${effectiveFilterWidth};\n wC+= ${dilationWidth}) {\n float dyC = float(dyCCorner + wC) / ${strideWidth}.0;\n\n if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n setOutput(dotProd);\n }\n `;\n }\n}\n\nexport class AvgPool3DBackpropProgram implements GPGPUProgram {\n variableNames = ['dy'];\n outputShape: number[];\n userCode: string;\n\n constructor(convInfo: Conv3DInfo) {\n this.outputShape = convInfo.inShape;\n const filterDepth = convInfo.filterDepth;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const strideDepth = convInfo.strideDepth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationDepth = convInfo.dilationDepth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterDepth = convInfo.effectiveFilterDepth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n\n const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;\n const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;\n const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;\n\n const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);\n\n this.userCode = `\n const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});\n const float avgMultiplier = float(${avgMultiplier});\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < ${effectiveFilterDepth};\n wD += ${dilationDepth}) {\n float dyD = float(dyDCorner + wD) / ${strideDepth}.0;\n\n if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < ${effectiveFilterHeight};\n wR += ${dilationHeight}) {\n float dyR = float(dyRCorner + wR) / ${strideHeight}.0;\n\n if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < ${effectiveFilterWidth};\n wC += ${dilationWidth}) {\n float dyC = float(dyCCorner + wC) / ${strideWidth}.0;\n\n if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n }\n setOutput(dotProd);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as broadcast_util from '../../ops/broadcast_util';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class BatchNormProgram implements GPGPUProgram {\n variableNames: string[];\n outputShape: number[] = [];\n userCode: string;\n\n constructor(\n xShape: number[], meanShape: number[], varianceShape: number[],\n offsetShape: number[]|null, scaleShape: number[]|null,\n varianceEpsilon: number) {\n this.variableNames = ['x', 'mean', 'variance'];\n broadcast_util.assertAndGetBroadcastShape(xShape, meanShape);\n broadcast_util.assertAndGetBroadcastShape(xShape, varianceShape);\n\n let offsetSnippet = '0.0';\n if (offsetShape != null) {\n broadcast_util.assertAndGetBroadcastShape(xShape, offsetShape);\n this.variableNames.push('offset');\n offsetSnippet = 'getOffsetAtOutCoords()';\n }\n\n let scaleSnippet = '1.0';\n if (scaleShape != null) {\n broadcast_util.assertAndGetBroadcastShape(xShape, scaleShape);\n this.variableNames.push('scale');\n scaleSnippet = 'getScaleAtOutCoords()';\n }\n\n this.outputShape = xShape;\n this.userCode = `\n void main() {\n float x = getXAtOutCoords();\n float mean = getMeanAtOutCoords();\n float variance = getVarianceAtOutCoords();\n float offset = ${offsetSnippet};\n float scale = ${scaleSnippet};\n float inv = scale * inversesqrt(variance + float(${varianceEpsilon}));\n setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as broadcast_util from '../../ops/broadcast_util';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class BatchNormPackedProgram implements GPGPUProgram {\n variableNames: string[];\n outputShape: number[];\n userCode: string;\n packedInputs = true;\n packedOutput = true;\n\n constructor(\n xShape: number[], meanShape: number[], varianceShape: number[],\n offsetShape: number[]|null, scaleShape: number[]|null,\n varianceEpsilon: number) {\n this.variableNames = ['x', 'mean', 'variance'];\n broadcast_util.assertAndGetBroadcastShape(xShape, meanShape);\n broadcast_util.assertAndGetBroadcastShape(xShape, varianceShape);\n\n let offsetSnippet = 'vec4(0.0)';\n if (offsetShape != null) {\n broadcast_util.assertAndGetBroadcastShape(xShape, offsetShape);\n this.variableNames.push('offset');\n offsetSnippet = 'getOffsetAtOutCoords()';\n }\n\n let scaleSnippet = 'vec4(1.0)';\n if (scaleShape != null) {\n broadcast_util.assertAndGetBroadcastShape(xShape, scaleShape);\n this.variableNames.push('scale');\n scaleSnippet = 'getScaleAtOutCoords()';\n }\n\n this.outputShape = xShape;\n this.userCode = `\n void main() {\n vec4 offset = ${offsetSnippet};\n vec4 scale = ${scaleSnippet};\n\n vec4 x = getXAtOutCoords();\n vec4 mean = getMeanAtOutCoords();\n vec4 variance = getVarianceAtOutCoords();\n\n vec4 inv = scale * inversesqrt(variance + vec4(${varianceEpsilon}));\n\n setOutput((x - mean) * inv + offset);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as broadcast_util from '../../ops/broadcast_util';\nimport {GPGPUProgram} from './gpgpu_math';\n\n// (Ar + Ai)(Br + Bi) =\n// ArBr + ArBi + AiBr + AiBi = ArBr - AB + ArBi + AiBr\n// Yr = ArBr - AB\n// Yi = ArBi + AiBr\nexport const COMPLEX_MULTIPLY = {\n REAL: 'return areal * breal - aimag * bimag;',\n IMAG: 'return areal * bimag + aimag * breal;'\n};\n\nexport class BinaryOpComplexProgram implements GPGPUProgram {\n variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];\n userCode: string;\n outputShape: number[];\n\n constructor(op: string, aShape: number[], bShape: number[]) {\n this.outputShape =\n broadcast_util.assertAndGetBroadcastShape(aShape, bShape);\n\n this.userCode = `\n float binaryOpComplex(\n float areal, float aimag, float breal, float bimag) {\n ${op}\n }\n\n void main() {\n float areal = getARealAtOutCoords();\n float aimag = getAImagAtOutCoords();\n float breal = getBRealAtOutCoords();\n float bimag = getBImagAtOutCoords();\n setOutput(binaryOpComplex(areal, aimag, breal, bimag));\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as broadcast_util from '../../ops/broadcast_util';\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nconst CHECK_NAN_SNIPPET = `\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n`;\n\nexport const ADD = 'return a + b;';\nexport const SUB = 'return a - b;';\nexport const MUL = 'return a * b;';\n\n// Without the equality check div produces 0.9999 for a = b, which when\n// floored can cause errors.\nexport const DIV = `\nif (a == b) {\n return 1.0;\n};\nreturn a / b;`;\n\n// We use native integer division to deal with floating point imprecision. Since\n// we implement floor division and glsl implements truncated division, we\n// correct for this by subtracting 1 from result when the result is negative and\n// there is a remainder.\nexport const INT_DIV = `\n float s = sign(a) * sign(b);\n int ia = round(a);\n int ib = round(b);\n if (ib != 0) {\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n return float(idiv(ia, ib, s));\n } else {\n return NAN;\n }\n`;\n\nexport const POW = `\nif(a < 0.0 && floor(b) < b){\n return NAN;\n}\nif (b == 0.0) {\n return 1.0;\n}\nreturn (round(mod(b, 2.0)) != 1) ?\n pow(abs(a), b) : sign(a) * pow(abs(a), b);\n`;\nexport const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';\n\nexport const EQUAL = `return float(a == b);`;\n\nexport const NOT_EQUAL = `return float(a != b);`;\n\nexport const LESS = `return float(a < b);`;\n\nexport const LESS_EQUAL = `return float(a <= b);`;\n\nexport const GREATER = `return float(a > b);`;\n\nexport const GREATER_EQUAL = `return float(a >= b);`;\n\nexport const LOGICAL_AND = `return float(a >= 1.0 && b >= 1.0);`;\n\nexport const LOGICAL_OR = `return float(a >= 1.0 || b >= 1.0);`;\n\nexport const MAX = CHECK_NAN_SNIPPET + `\n return max(a, b);\n`;\nexport const MIN = CHECK_NAN_SNIPPET + `\n return min(a, b);\n`;\nexport const MOD = `if (b == 0.0) return NAN;\n return mod(a, b);`;\n\nexport const ATAN2 = CHECK_NAN_SNIPPET + `\n return atan(a, b);\n`;\n\nexport const ELU_DER = `return (b >= 1.0) ? a : a * (b + 1.0);`;\n\nexport const PRELU = `return (a < 0.) ? b * a : a;`;\n\nexport class BinaryOpProgram implements GPGPUProgram {\n variableNames = ['A', 'B'];\n outputShape: number[];\n userCode: string;\n\n constructor(op: string, aShape: number[], bShape: number[]) {\n this.outputShape =\n broadcast_util.assertAndGetBroadcastShape(aShape, bShape);\n this.userCode = `\n float binaryOperation(float a, float b) {\n ${op}\n }\n\n void main() {\n float a = getAAtOutCoords();\n float b = getBAtOutCoords();\n setOutput(binaryOperation(a, b));\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as broadcast_util from '../../ops/broadcast_util';\nimport {sizeFromShape} from '../../util';\nimport {getChannels} from '../packing_util';\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nconst CHECK_NAN_SNIPPET = `\n result.r = isNaN.r > 0. ? NAN : result.r;\n result.g = isNaN.g > 0. ? NAN : result.g;\n result.b = isNaN.b > 0. ? NAN : result.b;\n result.a = isNaN.a > 0. ? NAN : result.a;\n`;\n\n// We do the same as in ./binaryop_gpu, with vec4 and ivec4.\n// On Linux, the vectorized implementation produces NaNs when a and b are 0.\nexport const DIV = `\n // vec4 one = vec4(equal(a, b));\n // return one + (vec4(1.0) - one) * a / b;\n vec4 result = a / b;\n if(a.x == b.x) {\n result.x = 1.;\n }\n if(a.y == b.y) {\n result.y = 1.;\n }\n if(a.z == b.z) {\n result.z = 1.;\n }\n if(a.w == b.w) {\n result.w = 1.;\n }\n\n return result;\n`;\n\nexport const INT_DIV = `\n ivec4 ia = round(a);\n ivec4 ib = round(b);\n bvec4 cond = notEqual(ib, ivec4(0));\n ivec4 result = ivec4(0);\n vec4 s = sign(a) * sign(b);\n\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n if (cond[0]) {\n result[0] = idiv(ia[0], ib[0], s[0]);\n }\n if (cond[1]) {\n result[1] = idiv(ia[1], ib[1], s[1]);\n }\n if (cond[2]) {\n result[2] = idiv(ia[2], ib[2], s[2]);\n }\n if (cond[3]) {\n result[3] = idiv(ia[3], ib[3], s[3]);\n }\n return vec4(result);\n`;\n\nexport const POW = `\n // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.\n vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));\n vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);\n vec4 result = multiplier * pow(abs(a), b);\n\n // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS\n bvec4 isExpZero = equal(b, vec4(0.0));\n result.r = isExpZero.r ? 1.0 : result.r;\n result.g = isExpZero.g ? 1.0 : result.g;\n result.b = isExpZero.b ? 1.0 : result.b;\n result.a = isExpZero.a ? 1.0 : result.a;\n\n vec4 isNaN = vec4(lessThan(a, vec4(0.0))) * vec4(lessThan(floor(b), b));\n ` +\n CHECK_NAN_SNIPPET + `\n return result;\n`;\n\nexport const PRELU = `\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n`;\n\nexport const ELU_DER = `\n vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));\n return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));\n`;\n\nexport const ATAN2 = `\n vec4 result = atan(a, b);\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n ` +\n CHECK_NAN_SNIPPET + `\n return result;\n`;\n\nexport const EQUAL = `\n return vec4(equal(a, b));\n`;\n\nexport const NOT_EQUAL = `\n return vec4(notEqual(a, b));\n`;\n\nexport const LESS = `\n return vec4(lessThan(a, b));\n`;\n\nexport const LESS_EQUAL = `\n return vec4(lessThanEqual(a, b));\n`;\n\nexport const GREATER = `\n return vec4(greaterThan(a, b));\n`;\n\nexport const GREATER_EQUAL = `\n return vec4(greaterThanEqual(a, b));\n`;\n\nexport const LOGICAL_AND = `\n return vec4(\n vec4(greaterThanEqual(a, vec4(1.0))) *\n vec4(greaterThanEqual(b, vec4(1.0))));\n`;\n\nexport const LOGICAL_OR = `\n return min(\n vec4(greaterThanEqual(a, vec4(1.0))) +\n vec4(greaterThanEqual(b, vec4(1.0))),\n vec4(1.0));\n`;\n\nexport const MAX = `\n vec4 result = vec4(max(a, b));\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n ` +\n CHECK_NAN_SNIPPET + `\n return result;\n`;\n\nexport const MIN = `\n vec4 result = vec4(min(a, b));\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n ` +\n CHECK_NAN_SNIPPET + `\n return result;\n`;\n\nexport const MOD = `\n vec4 result = mod(a, b);\n vec4 isNaN = vec4(equal(b, vec4(0.0)));\n ` +\n CHECK_NAN_SNIPPET + `\n return result;\n`;\n\nexport class BinaryOpPackedProgram implements GPGPUProgram {\n variableNames = ['A', 'B'];\n outputShape: number[];\n userCode: string;\n supportsBroadcasting = true;\n packedInputs = true;\n packedOutput = true;\n\n constructor(\n op: string, aShape: number[], bShape: number[],\n checkOutOfBounds = false) {\n this.outputShape =\n broadcast_util.assertAndGetBroadcastShape(aShape, bShape);\n const rank = this.outputShape.length;\n let checkOutOfBoundsString = '';\n if (checkOutOfBounds) {\n if (rank === 0 || sizeFromShape(this.outputShape) === 1) {\n checkOutOfBoundsString = `\n result.y = 0.;\n result.z = 0.;\n result.w = 0.;\n `;\n } else {\n const dtype = getCoordsDataType(rank);\n checkOutOfBoundsString = `\n ${dtype} coords = getOutputCoords();\n `;\n if (rank === 1) {\n checkOutOfBoundsString += `\n result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n `;\n } else {\n const channels = getChannels('coords', rank);\n checkOutOfBoundsString += `\n bool nextRowOutOfBounds =\n (${channels[rank - 2]} + 1) >= ${this.outputShape[rank - 2]};\n bool nextColOutOfBounds =\n (${channels[rank - 1]} + 1) >= ${this.outputShape[rank - 1]};\n result.y = nextColOutOfBounds ? 0. : result.y;\n result.z = nextRowOutOfBounds ? 0. : result.z;\n result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n `;\n }\n }\n }\n\n this.userCode = `\n vec4 binaryOperation(vec4 a, vec4 b) {\n ${op}\n }\n\n void main() {\n vec4 a = getAAtOutCoords();\n vec4 b = getBAtOutCoords();\n\n vec4 result = binaryOperation(a, b);\n ${checkOutOfBoundsString}\n\n setOutput(result);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUContext} from './gpgpu_context';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class ClipProgram implements GPGPUProgram {\n variableNames = ['A'];\n userCode: string;\n outputShape: number[];\n\n // Caching uniform locations for speed.\n minLoc: WebGLUniformLocation;\n maxLoc: WebGLUniformLocation;\n\n constructor(aShape: number[]) {\n this.outputShape = aShape;\n this.userCode = `\n uniform float minVal;\n uniform float maxVal;\n\n void main() {\n float value = getAAtOutCoords();\n if (isnan(value)) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, minVal, maxVal));\n }\n `;\n }\n\n getCustomSetupFunc(min: number, max: number) {\n return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {\n if (this.minLoc == null) {\n this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');\n this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');\n }\n gpgpu.gl.uniform1f(this.minLoc, min);\n gpgpu.gl.uniform1f(this.maxLoc, max);\n };\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUContext} from './gpgpu_context';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class ClipPackedProgram implements GPGPUProgram {\n variableNames = ['A'];\n packedInputs = true;\n packedOutput = true;\n userCode: string;\n outputShape: number[];\n\n // Caching uniform locations for speed.\n minLoc: WebGLUniformLocation;\n maxLoc: WebGLUniformLocation;\n\n constructor(aShape: number[]) {\n this.outputShape = aShape;\n this.userCode = `\n uniform float minVal;\n uniform float maxVal;\n\n void main() {\n vec4 value = getAAtOutCoords();\n\n if (any(isnan(value))) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, vec4(minVal), vec4(maxVal)));\n }\n `;\n }\n\n getCustomSetupFunc(min: number, max: number) {\n return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {\n if (this.minLoc == null) {\n this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');\n this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');\n }\n gpgpu.gl.uniform1f(this.minLoc, min);\n gpgpu.gl.uniform1f(this.maxLoc, max);\n };\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class ComplexAbsProgram implements GPGPUProgram {\n variableNames = ['real', 'imag'];\n userCode: string;\n outputShape: number[];\n\n constructor(shape: number[]) {\n this.outputShape = shape;\n this.userCode = `\n void main() {\n float re = abs(getRealAtOutCoords());\n float im = abs(getImagAtOutCoords());\n float mx = max(re, im);\n\n // sadly the length function in glsl is not underflow-safe\n // (at least not on Intel GPUs). So the safe solution is\n // to ensure underflow-safety in all cases.\n setOutput(\n mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))\n );\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as concat_util from '../../ops/concat_util';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class ConcatProgram implements GPGPUProgram {\n variableNames: string[];\n outputShape: number[] = [];\n userCode: string;\n\n // Concats 2d tensors along axis=1. See comments in MathBackendWebGL.concat().\n constructor(shapes: Array<[number, number]>) {\n this.outputShape = concat_util.computeOutShape(shapes, 1 /* axis */);\n this.variableNames = shapes.map((_, i) => `T${i}`);\n\n const offsets: number[] = new Array(shapes.length - 1);\n offsets[0] = shapes[0][1];\n for (let i = 1; i < offsets.length; i++) {\n offsets[i] = offsets[i - 1] + shapes[i][1];\n }\n\n const snippets = [`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`];\n for (let i = 1; i < offsets.length; i++) {\n const shift = offsets[i - 1];\n snippets.push(\n `else if (yC < ${offsets[i]}) ` +\n `setOutput(getT${i}(yR, yC-${shift}));`);\n }\n const lastIndex = offsets.length;\n const lastShift = offsets[offsets.length - 1];\n snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`);\n\n this.userCode = `\n void main() {\n ivec2 coords = getOutputCoords();\n int yR = coords.x;\n int yC = coords.y;\n\n ${snippets.join('\\n ')}\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as concat_util from '../../ops/concat_util';\nimport {getChannels} from '../packing_util';\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class ConcatPackedProgram implements GPGPUProgram {\n variableNames: string[];\n packedInputs = true;\n packedOutput = true;\n outputShape: number[] = [];\n userCode: string;\n\n constructor(shapes: number[][], axis: number) {\n this.outputShape = concat_util.computeOutShape(shapes, axis);\n const shape = this.outputShape;\n const rank = shape.length;\n const dtype = getCoordsDataType(rank);\n const coords = getChannels('coords', rank);\n const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);\n this.variableNames = shapes.map((_, i) => `T${i}`);\n\n const offsets: number[] = new Array(shapes.length - 1);\n offsets[0] = shapes[0][axis];\n for (let i = 1; i < offsets.length; i++) {\n offsets[i] = offsets[i - 1] + shapes[i][axis];\n }\n\n const channel = channels[axis];\n const lastChannels = channels.slice(-2);\n const allChannels = channels.join();\n\n let getValueSnippet = `if (${channel} < ${offsets[0]}) {\n return getChannel(\n getT0(${allChannels}), vec2(${lastChannels.join()}));\n }`;\n for (let i = 1; i < offsets.length; i++) {\n const shift = offsets[i - 1];\n // Note: the >= comparison below may seem unnecessary given the check\n // above but is needed to workaround branch execution issues on some\n // devices. It makes all the conditions exclusive without relying on\n // execution order.\n getValueSnippet += `\n if (${channel} < ${offsets[i]} && ${channel} >= ${offsets[i - 1]}) {\n return getChannel(\n getT${i}(${shiftedChannels(channels, channel, shift)}),\n vec2(${shiftedChannels(lastChannels, channel, shift)}));\n }`;\n }\n const lastIndex = offsets.length;\n const shift = offsets[offsets.length - 1];\n getValueSnippet += `\n return getChannel(\n getT${lastIndex}(${shiftedChannels(channels, channel, shift)}),\n vec2(${shiftedChannels(lastChannels, channel, shift)}));`;\n\n this.userCode = `\n float getValue(${channels.map(x => 'int ' + x)}) {\n ${getValueSnippet}\n }\n\n void main() {\n ${dtype} coords = getOutputCoords();\n vec4 result = vec4(getValue(${coords}), 0., 0., 0.);\n\n ${coords[rank - 1]} = ${coords[rank - 1]} + 1;\n if (${coords[rank - 1]} < ${shape[rank - 1]}) {\n result.g = getValue(${coords});\n }\n\n ${coords[rank - 2]} = ${coords[rank - 2]} + 1;\n if (${coords[rank - 2]} < ${shape[rank - 2]}) {\n result.a = getValue(${coords});\n }\n\n ${coords[rank - 1]} = ${coords[rank - 1]} - 1;\n if (${coords[rank - 2]} < ${shape[rank - 2]} &&\n ${coords[rank - 1]} < ${shape[rank - 1]}) {\n result.b = getValue(${coords});\n }\n setOutput(result);\n }\n `;\n }\n}\n\n/**\n * Return an expression for coordinates into a vector where a given channel\n * will be offset by [shift].\n *\n * @param channels the channels to consider\n * @param channel the channel we want shifted\n * @param shift the amount to subtract from the channel.\n *\n * @returns a string of the form 'x, y-[shift], z' where any one channel can\n * have the shift applied.\n */\nfunction shiftedChannels(channels: string[], channel: string, shift: number) {\n const channelIdx = channels.indexOf(channel);\n const res = channels.map((c, idx) => {\n if (idx === channelIdx) {\n return `${c} - ${shift}`;\n } else {\n return c;\n }\n });\n return res.join();\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class Conv2DDerFilterProgram implements GPGPUProgram {\n variableNames = ['x', 'dy'];\n outputShape: number[];\n userCode: string;\n\n constructor(convInfo: Conv2DInfo) {\n this.outputShape = convInfo.filterShape;\n\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n const isChannelsLast = convInfo.dataFormat === 'channelsLast';\n\n this.userCode = `\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int d2 = coords.w;\n\n // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int b = 0; b < ${convInfo.batchSize}; b++) {\n for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {\n int xR = wR + yR * ${strideHeight} - ${padTop};\n\n if (xR < 0 || xR >= ${convInfo.inHeight}) {\n continue;\n }\n\n for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {\n int xC = wC + yC * ${strideWidth} - ${padLeft};\n\n if (xC < 0 || xC >= ${convInfo.inWidth}) {\n continue;\n }\n\n if (${isChannelsLast}) {\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n } else {\n float dyValue = getDy(b, d2, yR, yC);\n float xValue = getX(b, d1, xR, xC);\n dotProd += (xValue * dyValue);\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n `;\n }\n}\n\nexport class Conv2DDerInputProgram implements GPGPUProgram {\n variableNames = ['dy', 'W'];\n outputShape: number[];\n userCode: string;\n\n constructor(convInfo: Conv2DInfo) {\n this.outputShape = convInfo.inShape;\n\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const isChannelsLast = convInfo.dataFormat === 'channelsLast';\n\n const padTop = filterHeight - 1 - convInfo.padInfo.top;\n const padLeft = filterWidth - 1 - convInfo.padInfo.left;\n\n const rowDim = isChannelsLast ? 1 : 2;\n const colDim = isChannelsLast ? 2 : 3;\n const channelDim = isChannelsLast ? 3 : 1;\n\n this.userCode = `\n const ivec2 pads = ivec2(${padTop}, ${padLeft});\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[${channelDim}];\n\n ivec2 dyCorner = ivec2(coords[${rowDim}], coords[${colDim}]) - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < ${filterHeight}; wR++) {\n float dyR = float(dyRCorner + wR) / ${strideHeight}.0;\n\n if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = ${filterHeight} - 1 - wR;\n\n for (int wC = 0; wC < ${filterWidth}; wC++) {\n float dyC = float(dyCCorner + wC) / ${strideWidth}.0;\n\n if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = ${filterWidth} - 1 - wC;\n\n for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {\n\n if (${isChannelsLast}) {\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n } else {\n float xValue = getDy(batch, d2, idyR, idyC);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n `;\n }\n}\n\nexport class Conv3DDerFilterProgram implements GPGPUProgram {\n variableNames = ['x', 'dy'];\n outputShape: number[];\n userCode: string;\n\n constructor(convInfo: Conv3DInfo) {\n this.outputShape = convInfo.filterShape;\n\n const strideDepth = convInfo.strideDepth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const padFront = convInfo.padInfo.front;\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n\n this.userCode = `\n void main() {\n ivec5 coords = getOutputCoords();\n int wF = coords.x;\n int wR = coords.y;\n int wC = coords.z;\n int d1 = coords.w;\n int d2 = coords.u;\n\n float dotProd = 0.0;\n\n for (int b = 0; b < ${convInfo.batchSize}; b++) {\n for (int yF = 0; yF < ${convInfo.outDepth}; yF++) {\n int xF = wF + yF * ${strideDepth} - ${padFront};\n\n if (xF < 0 || xF >= ${convInfo.inDepth}) {\n continue;\n }\n\n for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {\n int xR = wR + yR * ${strideHeight} - ${padTop};\n\n if (xR < 0 || xR >= ${convInfo.inHeight}) {\n continue;\n }\n\n for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {\n int xC = wC + yC * ${strideWidth} - ${padLeft};\n\n if (xC < 0 || xC >= ${convInfo.inWidth}) {\n continue;\n }\n\n float dyValue = getDy(b, yF, yR, yC, d2);\n float xValue = getX(b, xF, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n `;\n }\n}\n\nexport class Conv3DDerInputProgram implements GPGPUProgram {\n variableNames = ['dy', 'W'];\n outputShape: number[];\n userCode: string;\n\n constructor(convInfo: Conv3DInfo) {\n this.outputShape = convInfo.inShape;\n\n const filterDepth = convInfo.filterDepth;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const strideDepth = convInfo.strideDepth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n\n const padFront = filterDepth - 1 - convInfo.padInfo.front;\n const padTop = filterHeight - 1 - convInfo.padInfo.top;\n const padLeft = filterWidth - 1 - convInfo.padInfo.left;\n\n this.userCode = `\n const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d1 = coords.u;\n\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyFCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n float dotProd = 0.0;\n for (int wF = 0; wF < ${filterDepth}; wF++) {\n float dyF = float(dyFCorner + wF) / ${strideDepth}.0;\n\n if (dyF < 0.0 || dyF >= ${convInfo.outDepth}.0 || fract(dyF) > 0.0) {\n continue;\n }\n int idyF = int(dyF);\n\n int wFPerm = ${filterDepth} - 1 - wF;\n\n for (int wR = 0; wR < ${filterHeight}; wR++) {\n float dyR = float(dyRCorner + wR) / ${strideHeight}.0;\n\n if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = ${filterHeight} - 1 - wR;\n\n for (int wC = 0; wC < ${filterWidth}; wC++) {\n float dyC = float(dyCCorner + wC) / ${strideWidth}.0;\n\n if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = ${filterWidth} - 1 - wC;\n\n for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {\n float xValue = getDy(batch, idyF, idyR, idyC, d2);\n float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n }\n }\n }\n setOutput(dotProd);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Conv2DInfo} from '../../ops/conv_util';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class DepthwiseConv2DDerFilterProgram implements GPGPUProgram {\n variableNames = ['x', 'dy'];\n outputShape: number[];\n userCode: string;\n\n constructor(convInfo: Conv2DInfo) {\n this.outputShape = convInfo.filterShape;\n\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n const channelMul = convInfo.outChannels / convInfo.inChannels;\n\n this.userCode = `\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int dm = coords.w;\n int d2 = d1 * ${channelMul} + dm;\n\n float dotProd = 0.0;\n\n // TO DO: Vec4 over the batch size\n for (int b = 0; b < ${convInfo.batchSize}; b++) {\n for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {\n int xR = wR + yR * ${strideHeight} - ${padTop};\n\n if (xR < 0 || xR >= ${convInfo.inHeight}) {\n continue;\n }\n\n for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {\n int xC = wC + yC * ${strideWidth} - ${padLeft};\n\n if (xC < 0 || xC >= ${convInfo.inWidth}) {\n continue;\n }\n\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n setOutput(dotProd);\n }\n `;\n }\n}\n\nexport class DepthwiseConv2DDerInputProgram implements GPGPUProgram {\n variableNames = ['dy', 'W'];\n outputShape: number[];\n userCode: string;\n\n constructor(convInfo: Conv2DInfo) {\n this.outputShape = convInfo.inShape;\n\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n\n const padTop = filterHeight - 1 - convInfo.padInfo.top;\n const padLeft = filterWidth - 1 - convInfo.padInfo.left;\n const channelMul = convInfo.outChannels / convInfo.inChannels;\n\n this.userCode = `\n const ivec2 pads = ivec2(${padTop}, ${padLeft});\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[3];\n ivec2 dyCorner = coords.yz - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n float dotProd = 0.0;\n\n for (int wR = 0; wR < ${filterHeight}; wR++) {\n float dyR = float(dyRCorner + wR) / ${strideHeight}.0;\n\n if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = ${filterHeight} - 1 - wR;\n\n for (int wC = 0; wC < ${filterWidth}; wC++) {\n float dyC = float(dyCCorner + wC) / ${strideWidth}.0;\n\n if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = ${filterWidth} - 1 - wC;\n\n // TO DO: Vec4 over the channelMul\n for (int dm = 0; dm < ${channelMul}; dm++) {\n int d2 = d1 * ${channelMul} + dm;\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, dm);\n dotProd += xValue * wValue;\n }\n }\n }\n setOutput(dotProd);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class Conv2DProgram implements GPGPUProgram {\n variableNames = ['x', 'W'];\n outputShape: number[];\n userCode: string;\n\n constructor(\n convInfo: Conv2DInfo, addBias = false, activation: string = null,\n hasPreluActivationWeights = false) {\n this.outputShape = convInfo.outShape;\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n\n const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;\n const inputDepthVec4Remainder = convInfo.inChannels % 4;\n const isChannelsLast = convInfo.dataFormat === 'channelsLast';\n\n const rowDim = isChannelsLast ? 1 : 2;\n const colDim = isChannelsLast ? 2 : 3;\n const channelDim = isChannelsLast ? 3 : 1;\n\n let activationSnippet = '', applyActivationSnippet = '';\n if (activation) {\n if (hasPreluActivationWeights) {\n activationSnippet = `float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n ${activation}\n }`;\n } else {\n activationSnippet = `\n float activation(float x) {\n ${activation}\n }\n `;\n }\n\n applyActivationSnippet = `result = activation(result);`;\n }\n\n const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';\n if (addBias) {\n this.variableNames.push('bias');\n }\n\n if (hasPreluActivationWeights) {\n this.variableNames.push('preluActivationWeights');\n }\n\n this.userCode = `\n ${activationSnippet}\n\n const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});\n const ivec2 pads = ivec2(${padTop}, ${padLeft});\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d2 = coords[${channelDim}];\n\n ivec2 xRCCorner =\n ivec2(coords[${rowDim}], coords[${colDim}]) * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < ${filterHeight}; wR++) {\n int xR = xRCorner + wR * ${dilationHeight};\n\n if (xR < 0 || xR >= ${convInfo.inHeight}) {\n continue;\n }\n\n for (int wC = 0; wC < ${filterWidth}; wC++) {\n int xC = xCCorner + wC * ${dilationWidth};\n\n if (xC < 0 || xC >= ${convInfo.inWidth}) {\n continue;\n }\n\n for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {\n vec4 wValues = vec4(\n getW(wR, wC, d1, d2),\n getW(wR, wC, d1 + 1, d2),\n getW(wR, wC, d1 + 2, d2),\n getW(wR, wC, d1 + 3, d2)\n );\n\n if (${isChannelsLast}) {\n vec4 xValues = vec4(\n getX(batch, xR, xC, d1),\n getX(batch, xR, xC, d1 + 1),\n getX(batch, xR, xC, d1 + 2),\n getX(batch, xR, xC, d1 + 3)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec4 xValues = vec4(\n getX(batch, d1, xR, xC),\n getX(batch, d1 + 1, xR, xC),\n getX(batch, d1 + 2, xR, xC),\n getX(batch, d1 + 3, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n\n if (${inputDepthVec4Remainder === 1}) {\n\n if (${isChannelsLast}) {\n dotProd +=\n getX(batch, xR, xC, ${inputDepthNearestVec4}) *\n getW(wR, wC, ${inputDepthNearestVec4}, d2);\n } else {\n dotProd +=\n getX(batch, ${inputDepthNearestVec4}, xR, xC) *\n getW(wR, wC, ${inputDepthNearestVec4}, d2);\n }\n\n } else if (${inputDepthVec4Remainder === 2}) {\n vec2 wValues = vec2(\n getW(wR, wC, ${inputDepthNearestVec4}, d2),\n getW(wR, wC, ${inputDepthNearestVec4} + 1, d2)\n );\n\n if (${isChannelsLast}) {\n vec2 xValues = vec2(\n getX(batch, xR, xC, ${inputDepthNearestVec4}),\n getX(batch, xR, xC, ${inputDepthNearestVec4} + 1)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec2 xValues = vec2(\n getX(batch, ${inputDepthNearestVec4}, xR, xC),\n getX(batch, ${inputDepthNearestVec4} + 1, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n } else if (${inputDepthVec4Remainder === 3}) {\n vec3 wValues = vec3(\n getW(wR, wC, ${inputDepthNearestVec4}, d2),\n getW(wR, wC, ${inputDepthNearestVec4} + 1, d2),\n getW(wR, wC, ${inputDepthNearestVec4} + 2, d2)\n );\n\n if (${isChannelsLast}) {\n vec3 xValues = vec3(\n getX(batch, xR, xC, ${inputDepthNearestVec4}),\n getX(batch, xR, xC, ${inputDepthNearestVec4} + 1),\n getX(batch, xR, xC, ${inputDepthNearestVec4} + 2)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec3 xValues = vec3(\n getX(batch, ${inputDepthNearestVec4}, xR, xC),\n getX(batch, ${inputDepthNearestVec4} + 1, xR, xC),\n getX(batch, ${inputDepthNearestVec4} + 2, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n }\n }\n }\n\n float result = dotProd;\n ${addBiasSnippet}\n ${applyActivationSnippet}\n setOutput(result);\n }\n `;\n }\n}\n\nexport class Conv3DProgram implements GPGPUProgram {\n variableNames = ['x', 'W'];\n outputShape: number[];\n userCode: string;\n\n constructor(convInfo: Conv3DInfo) {\n this.outputShape = convInfo.outShape;\n const padFront = convInfo.padInfo.front;\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n const strideDepth = convInfo.strideDepth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationDepth = convInfo.dilationDepth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const filterDepth = convInfo.filterDepth;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n\n const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;\n const inputDepthVec4Remainder = convInfo.inChannels % 4;\n\n this.userCode = `\n const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${\n strideWidth});\n const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d2 = coords.u;\n\n ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xFCorner = xFRCCorner.x;\n int xRCorner = xFRCCorner.y;\n int xCCorner = xFRCCorner.z;\n\n // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get\n // y(yF, yR, yC, d2). ? = to be determined. : = across all\n // values in that axis.\n float dotProd = 0.0;\n for (int wF = 0; wF < ${filterDepth}; wF++) {\n int xF = xFCorner + wF * ${dilationDepth};\n\n if (xF < 0 || xF >= ${convInfo.inDepth}) {\n continue;\n }\n\n for (int wR = 0; wR < ${filterHeight}; wR++) {\n int xR = xRCorner + wR * ${dilationHeight};\n\n if (xR < 0 || xR >= ${convInfo.inHeight}) {\n continue;\n }\n\n for (int wC = 0; wC < ${filterWidth}; wC++) {\n int xC = xCCorner + wC * ${dilationWidth};\n\n if (xC < 0 || xC >= ${convInfo.inWidth}) {\n continue;\n }\n\n for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {\n vec4 xValues = vec4(\n getX(batch, xF, xR, xC, d1),\n getX(batch, xF, xR, xC, d1 + 1),\n getX(batch, xF, xR, xC, d1 + 2),\n getX(batch, xF, xR, xC, d1 + 3)\n );\n vec4 wValues = vec4(\n getW(wF, wR, wC, d1, d2),\n getW(wF, wR, wC, d1 + 1, d2),\n getW(wF, wR, wC, d1 + 2, d2),\n getW(wF, wR, wC, d1 + 3, d2)\n );\n\n dotProd += dot(xValues, wValues);\n }\n\n if (${inputDepthVec4Remainder === 1}) {\n dotProd +=\n getX(batch, xF, xR, xC, ${inputDepthNearestVec4}) *\n getW(wF, wR, wC, ${inputDepthNearestVec4}, d2);\n } else if (${inputDepthVec4Remainder === 2}) {\n vec2 xValues = vec2(\n getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),\n getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1)\n );\n vec2 wValues = vec2(\n getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),\n getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2)\n );\n dotProd += dot(xValues, wValues);\n } else if (${inputDepthVec4Remainder === 3}) {\n vec3 xValues = vec3(\n getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),\n getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1),\n getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 2)\n );\n vec3 wValues = vec3(\n getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),\n getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2),\n getW(wF, wR, wC, ${inputDepthNearestVec4} + 2, d2)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Conv2DInfo} from '../../ops/conv_util';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class DepthwiseConv2DProgram implements GPGPUProgram {\n variableNames = ['x', 'W'];\n outputShape: number[];\n userCode: string;\n\n constructor(\n convInfo: Conv2DInfo, addBias = false, activation: string = null,\n hasPreluActivation = false) {\n this.outputShape = convInfo.outShape;\n\n const xNumRows = convInfo.inHeight;\n const xNumCols = convInfo.inWidth;\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const channelMul = convInfo.outChannels / convInfo.inChannels;\n\n let activationSnippet = '', applyActivationSnippet = '';\n if (activation) {\n if (hasPreluActivation) {\n activationSnippet = `float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n ${activation}\n }`;\n } else {\n activationSnippet = `\n float activation(float x) {\n ${activation}\n }\n `;\n }\n\n applyActivationSnippet = `result = activation(result);`;\n }\n\n const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';\n if (addBias) {\n this.variableNames.push('bias');\n }\n\n if (hasPreluActivation) {\n this.variableNames.push('preluActivationWeights');\n }\n\n this.userCode = `\n ${activationSnippet}\n\n const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});\n const ivec2 pads = ivec2(${padTop}, ${padLeft});\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / ${channelMul};\n int q = d2 - d1 * ${channelMul};\n\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.\n for (int wR = 0; wR < ${filterHeight}; wR++) {\n int xR = xRCorner + wR * ${dilationHeight};\n\n if (xR < 0 || xR >= ${xNumRows}) {\n continue;\n }\n\n for (int wC = 0; wC < ${filterWidth}; wC++) {\n int xC = xCCorner + wC * ${dilationWidth};\n\n if (xC < 0 || xC >= ${xNumCols}) {\n continue;\n }\n\n float xVal = getX(batch, xR, xC, d1);\n float wVal = getW(wR, wC, d1, q);\n dotProd += xVal * wVal;\n }\n }\n\n float result = dotProd;\n ${addBiasSnippet}\n ${applyActivationSnippet}\n setOutput(result);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Conv2DInfo} from '../../ops/conv_util';\nimport * as util from '../../util';\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class DepthwiseConvPacked2DProgram implements GPGPUProgram {\n variableNames = ['x', 'W'];\n packedInputs = true;\n packedOutput = true;\n outputShape: number[];\n userCode: string;\n\n constructor(\n convInfo: Conv2DInfo, addBias = false, activation: string = null,\n hasPreluActivation = false) {\n this.outputShape = convInfo.outShape;\n\n const xNumRows = convInfo.inHeight;\n const xNumCols = convInfo.inWidth;\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const texelsAcross = filterWidth;\n\n let mainLoop = `int xR; int xC; int xCOffset;`;\n\n for (let r = 0; r < filterHeight; r++) {\n for (let c = 0; c < filterWidth; c++) {\n mainLoop += `\n vec4 xTexelR${r}C${c * 2} = vec4(0.);\n vec4 wR${r}C${c} = vec4(0.);\n vec4 xR${r}C${c} = vec4(0.);`;\n }\n }\n\n /**\n * This vectorized implementation works by gathering the values needed for\n * each output channel's dot product into vec4's and then multiplying them\n * all together (this happens in the final double for-loop below). Most of\n * the main loop consists of constructing these vec4's with the minimum\n * number of texture2D calls, which means making use of all four returned\n * values from a texture2D call at once.\n */\n for (let r = 0; r < filterHeight; r++) {\n for (let texelC = 0; texelC < texelsAcross; texelC++) {\n const c = texelC * 2;\n\n mainLoop += `\n xR = xRCorner + ${r * dilationHeight};\n xC = xCCorner + ${c * dilationWidth};\n `;\n\n if (strideWidth === 1) {\n if (c < filterWidth) {\n // If padding is odd, the outer texels have to be composed.\n if (padLeft % 2 === 1) {\n // TODO: Ensure vec4 previous does not result in redundant sample,\n // and avoid setting xTexelRC's that exceed the boundary in the\n // first place rather than resetting them to vec4(0)).\n\n // To compute xCOffset:\n // - If padding is odd, we must add 1 to ensure we ask for an\n // even-numbered row.\n // - We subtract 2 to access the previous texel.\n\n mainLoop += `\n xCOffset = xC + 1;\n if(xR >= 0 && xR < ${xNumRows} && xCOffset >= 0 && xCOffset < ${\n xNumCols}) {\n xTexelR${r}C${c} = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if(xCOffset + 1 >= ${xNumCols}) {\n xTexelR${r}C${c}.zw = vec2(0.);\n }\n } else {\n xTexelR${r}C${c} = vec4(0.);\n }\n\n xCOffset = xC + 1 - 2;\n if(xR >= 0 && xR < ${xNumRows} && xCOffset >= 0 && xCOffset < ${\n xNumCols}) {\n vec4 previous = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if(xCOffset + 1 >= ${xNumCols}) {\n previous.zw = vec2(0.);\n }\n\n xR${r}C${c} = vec4(previous.zw, xTexelR${r}C${c}.xy);\n } else {\n xR${r}C${c} = vec4(0, 0, xTexelR${r}C${c}.xy);\n }\n `;\n } else {\n // Padding is even, so xRC corresponds to a single texel.\n mainLoop += `\n if(xR >= 0 && xR < ${xNumRows} && xC >= 0 && xC < ${xNumCols}) {\n xTexelR${r}C${c} = getX(batch, xR, xC, d1);\n } else {\n xTexelR${r}C${c} = vec4(0.);\n }\n\n xR${r}C${c} = xTexelR${r}C${c};\n `;\n }\n\n if (c + 1 < filterWidth) {\n // If dilation is even, the second entry should match the first\n // (either both are composed or both are single samples). But if\n // dilation is odd, then the second entry should be the opposite\n // of the first (if the first is composed, the second is a single\n // sample, and vice versa.)\n\n const nextTexelOffset = padLeft % 2 === 0 ?\n util.nearestLargerEven(dilationWidth) :\n dilationWidth;\n\n if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||\n (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {\n mainLoop += `\n xCOffset = xC + ${padLeft % 2} + ${nextTexelOffset};\n\n if(xR >= 0 && xR < ${xNumRows} &&\n xCOffset >= 0 && xCOffset < ${xNumCols}) {\n xTexelR${r}C${c + 2} = getX(batch, xR, xCOffset, d1);\n }\n `;\n\n // If dilation > 1 then the xRC's will not be able to share any\n // values, so each xRC will require two unique calls to getX.\n if (dilationWidth > 1) {\n mainLoop += `\n xCOffset -= 2;\n if(xR >= 0 && xR < ${xNumRows} &&\n xCOffset >= 0 && xCOffset < ${xNumCols}) {\n xTexelR${r}C${c} = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR${r}C${c} = vec4(0.);\n }\n `;\n }\n\n mainLoop += `\n xR${r}C${c + 1} = vec4(\n xTexelR${r}C${c}.zw, xTexelR${r}C${c + 2}.xy);\n `;\n } else {\n mainLoop += `\n xCOffset = xC + ${nextTexelOffset};\n\n if(xR >= 0 && xR < ${xNumRows} &&\n xCOffset >= 0 && xCOffset < ${xNumCols}) {\n xTexelR${r}C${c + 2} = getX(batch, xR, xCOffset, d1);\n }\n\n xR${r}C${c + 1} = xTexelR${r}C${c + 2};\n `;\n }\n }\n }\n } else { // stride > 1\n if (c < filterWidth) {\n mainLoop += `\n if(xR >= 0 && xR < ${xNumRows}) {\n `;\n\n // Depending on whether padLeft is even or odd, we want either the\n // xy or zw channels from X texels for xR${r}C${c}. If padLeft is\n // even, xR${r}C${c + 1} is simply the zw channels of texels we've\n // already sampled. But if padLeft is odd, xR${r}C{$c + 1}.zw will\n // need to come from the xy channels of a new texel, hence the `vec4\n // final` initialized below.\n if (padLeft % 2 === 1) {\n mainLoop += `\n xCOffset = xC + 1 - ${strideWidth};\n if(xCOffset >= 0 && xCOffset < ${xNumCols}) {\n xTexelR${r}C${c} = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR${r}C${c} = vec4(0.);\n }\n\n if(xC + 1 >= 0 && xC + 1 < ${xNumCols}) {\n xTexelR${r}C${c + 2} = getX(batch, xR, xC + 1, d1);\n } else {\n xTexelR${r}C${c + 2} = vec4(0.);\n }\n\n xR${r}C${c} = vec4(\n xTexelR${r}C${c}.zw, xTexelR${r}C${c + 2}.zw);\n `;\n\n if (c + 1 < filterWidth) {\n mainLoop += `\n vec4 final = vec4(0.);\n xCOffset = xC + 1 + ${strideWidth};\n if(xCOffset >= 0 && xCOffset < ${xNumCols}) {\n final = getX(batch, xR, xCOffset, d1);\n }\n xR${r}C${c + 1} = vec4(xTexelR${r}C${c + 2}.xy, final.xy);\n `;\n }\n } else {\n mainLoop += `\n if(xC >= 0 && xC < ${xNumCols}) {\n xTexelR${r}C${c} = getX(batch, xR, xC, d1);\n } else {\n xTexelR${r}C${c} = vec4(0.);\n }\n\n xCOffset = xC + ${strideWidth};\n if(xCOffset >= 0 && xCOffset < ${xNumCols}) {\n xTexelR${r}C${c + 2} = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR${r}C${c + 2} = vec4(0.);\n }\n\n xR${r}C${c} = vec4(\n xTexelR${r}C${c}.xy, xTexelR${r}C${c + 2}.xy);\n `;\n\n if (c + 1 < filterWidth) {\n mainLoop += `\n xR${r}C${c + 1} = vec4(\n xTexelR${r}C${c}.zw, xTexelR${r}C${c + 2}.zw);\n `;\n }\n }\n\n mainLoop += `}`;\n }\n }\n\n if (c < filterWidth) {\n mainLoop += `\n vec4 wTexelR${r}C${c} = getW(${r}, ${c}, d1, q);\n wR${r}C${c} = vec4(wTexelR${r}C${c}.xz, wTexelR${r}C${c}.xz);\n `;\n\n if (c + 1 < filterWidth) {\n mainLoop += `\n vec4 wTexelR${r}C${c + 1} = getW(${r}, ${c + 1}, d1, q);\n wR${r}C${c + 1} =\n vec4(wTexelR${r}C${c + 1}.xz, wTexelR${r}C${c + 1}.xz);`;\n }\n }\n }\n }\n\n for (let r = 0; r < filterHeight; r++) {\n for (let c = 0; c < filterWidth; c++) {\n mainLoop += `dotProd += xR${r}C${c} * wR${r}C${c};`;\n }\n }\n\n let activationSnippet = '', applyActivationSnippet = '';\n if (activation) {\n if (hasPreluActivation) {\n activationSnippet = `vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n ${activation}\n }`;\n } else {\n activationSnippet = `vec4 activation(vec4 x) {\n ${activation}\n }`;\n }\n\n applyActivationSnippet = `result = activation(result);`;\n }\n\n const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';\n if (addBias) {\n this.variableNames.push('bias');\n }\n\n if (hasPreluActivation) {\n this.variableNames.push('preluActivationWeights');\n }\n\n this.userCode = `\n ${activationSnippet}\n\n const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});\n const ivec2 pads = ivec2(${padTop}, ${padLeft});\n\n void main() {\n\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2;\n int q = 0;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n vec4 dotProd = vec4(0.);\n\n ${mainLoop}\n\n vec4 result = dotProd;\n ${addBiasSnippet}\n ${applyActivationSnippet}\n setOutput(result);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class CropAndResizeProgram implements GPGPUProgram {\n variableNames = ['Image', 'Boxes', 'BoxInd'];\n outputShape: number[] = [];\n userCode: string;\n\n constructor(\n imageShape: [number, number, number, number], boxShape: [number, number],\n cropSize: [number, number], method: 'bilinear'|'nearest',\n extrapolationValue: number) {\n const [batch, imageHeight, imageWidth, depth] = imageShape;\n const [numBoxes, ] = boxShape;\n const [cropHeight, cropWidth] = cropSize;\n this.outputShape = [numBoxes, cropHeight, cropWidth, depth];\n const methodId = method === 'bilinear' ? 1 : 0;\n\n const [inputHeightFloat, inputWidthFloat] =\n [`${imageHeight - 1}.0`, `${imageWidth - 1}.0`];\n\n const [heightRatio, heightScale, inY] = cropHeight > 1 ?\n [\n `${(imageHeight - 1) / (cropHeight - 1)}`,\n '(y2-y1) * height_ratio',\n `y1*${inputHeightFloat} + float(y)*(height_scale)`,\n ] :\n [\n '0.0',\n '0.0',\n `0.5 * (y1+y2) * ${inputHeightFloat}`,\n ];\n const [widthRatio, widthScale, inX] = cropWidth > 1 ?\n [\n `${(imageWidth - 1) / (cropWidth - 1)}`,\n '(x2-x1) * width_ratio',\n `x1*${inputWidthFloat} + float(x)*(width_scale)`,\n ] :\n [\n '0.0',\n '0.0',\n `0.5 * (x1+x2) * ${inputWidthFloat}`,\n ];\n\n // Reference implementation\n // tslint:disable-next-line:max-line-length\n // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc\n this.userCode = `\n const float height_ratio = float(${heightRatio});\n const float width_ratio = float(${widthRatio});\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int y = coords[1];\n int x = coords[2];\n int d = coords[3];\n\n // get box vals\n float y1 = getBoxes(b,0);\n float x1 = getBoxes(b,1);\n float y2 = getBoxes(b,2);\n float x2 = getBoxes(b,3);\n\n // get image in batch index\n int bInd = round(getBoxInd(b));\n if(bInd < 0 || bInd >= ${batch}) {\n return;\n }\n\n float height_scale = ${heightScale};\n float width_scale = ${widthScale};\n\n float in_y = ${inY};\n if( in_y < 0.0 || in_y > ${inputHeightFloat} ) {\n setOutput(float(${extrapolationValue}));\n return;\n }\n float in_x = ${inX};\n if( in_x < 0.0 || in_x > ${inputWidthFloat} ) {\n setOutput(float(${extrapolationValue}));\n return;\n }\n\n vec2 sourceFracIndexCR = vec2(in_x,in_y);\n if(${methodId} == 1) {\n // Compute the four integer indices.\n ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);\n ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));\n\n float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);\n float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);\n float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);\n float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);\n\n vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);\n\n float top = topLeft + (topRight - topLeft) * fracCR.x;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;\n float newValue = top + (bottom - top) * fracCR.y;\n setOutput(newValue);\n } else {\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestCR = ivec2(floor(\n sourceFracIndexCR + vec2(0.5,0.5)));\n float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);\n setOutput(newValue);\n }\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class CumSumProgram implements GPGPUProgram {\n variableNames = ['x'];\n outputShape: number[];\n userCode: string;\n\n constructor(shape: number[], exclusive: boolean, reverse: boolean) {\n this.outputShape = shape;\n const rank = shape.length;\n const finalDim = shape[shape.length - 1];\n const comparator = reverse ? '<' : '>';\n\n this.userCode = `\n int getIndex(int i) {\n ${reverse ? `return ${finalDim} -i - 1;` : 'return i;'}\n }\n\n void main() {\n ${getCoordsDataType(rank)} coords = getOutputCoords();\n int end = ${getFinalCoord(rank, 'coords')};\n float val = 0.0;\n for (int i = ${finalDim} - 1; i >= 0; i -= 1) {\n int idx = getIndex(i);\n if (idx ${comparator} end) {\n continue;\n }\n if (idx == end && ${exclusive}) {\n continue;\n }\n ${getFinalCoord(rank, 'coords')} = idx;\n val += getX(${getCoords(rank, 'coords')});\n }\n setOutput(val);\n }\n `;\n }\n}\n\nfunction getCoords(rank: number, name: string): string {\n if (rank === 1) {\n return `${name}`;\n } else if (rank === 2) {\n return `${name}.x, ${name}.y`;\n } else if (rank === 3) {\n return `${name}.x, ${name}.y, ${name}.z`;\n } else if (rank === 4) {\n return `${name}.x, ${name}.y, ${name}.z, ${name}.w`;\n } else {\n throw Error(`Cumulative sum for rank ${rank} is not yet supported`);\n }\n}\n\nfunction getFinalCoord(rank: number, name: string): string {\n if (rank === 1) {\n return `${name}`;\n } else if (rank === 2) {\n return `${name}.y`;\n } else if (rank === 3) {\n return `${name}.z`;\n } else if (rank === 4) {\n return `${name}.w`;\n } else {\n throw Error(`Cumulative sum for rank ${rank} is not yet supported`);\n }\n}\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getGlslDifferences} from './glsl_version';\nimport {GPGPUProgram} from './gpgpu_math';\nimport * as shader_util from './shader_compiler_util';\nimport {getDenseTexShape, PackingScheme} from './tex_util';\n\nexport class DecodeMatrixProgram implements GPGPUProgram {\n variableNames = ['A'];\n userCode: string;\n outputShape: [number, number, number];\n packedInputs = false;\n packedOutput = true;\n outPackingScheme = PackingScheme.DENSE;\n\n constructor(outputShape: [number, number, number]) {\n const texShape = getDenseTexShape(outputShape);\n const glsl = getGlslDifferences();\n this.outputShape = outputShape;\n\n this.userCode = `\n ivec3 outCoordsFromFlatIndex(int index) {\n ${\n shader_util.getLogicalCoordinatesFromFlatIndex(\n ['r', 'c', 'd'], outputShape)}\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${texShape[0]}, ${texShape[1]}));\n int index = 4 * (resTexRC.x * ${texShape[1]} + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getA(rc.x, rc.y, rc.z);\n }\n\n ${glsl.output} = result;\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getGlslDifferences} from './glsl_version';\nimport {GPGPUProgram} from './gpgpu_math';\nimport * as shader_util from './shader_compiler_util';\nimport {getDenseTexShape, PackingScheme} from './tex_util';\n\nexport class DecodeMatrixPackedProgram implements GPGPUProgram {\n variableNames = ['A'];\n userCode: string;\n packedInputs = true;\n packedOutput = true;\n outputShape: [number, number, number];\n outPackingScheme = PackingScheme.DENSE;\n\n constructor(outputShape: [number, number, number]) {\n const texShape = getDenseTexShape(outputShape);\n const glsl = getGlslDifferences();\n this.outputShape = outputShape;\n\n this.userCode = `\n ivec3 outCoordsFromFlatIndex(int index) {\n ${\n shader_util.getLogicalCoordinatesFromFlatIndex(\n ['r', 'c', 'd'], outputShape)}\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(${texShape[0]}, ${texShape[1]}));\n int index = 4 * (resTexRC.x * ${texShape[1]} + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));\n }\n\n ${glsl.output} = result;\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class DepthToSpaceProgram implements GPGPUProgram {\n variableNames = ['x'];\n outputShape: number[] = [];\n userCode: string;\n blockSize: number;\n dataFormat: string;\n\n constructor(\n outputShape: number[], blockSize: number, dataFormat: 'NHWC'|'NCHW') {\n this.outputShape = outputShape;\n this.blockSize = blockSize;\n this.dataFormat = dataFormat;\n this.userCode = `\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int h = ${this.getHeightCoordString()};\n int w = ${this.getWidthCoordString()};\n int d = ${this.getDepthCoordString()};\n\n int in_h = h / ${blockSize};\n int offset_h = imod(h, ${blockSize});\n int in_w = w / ${blockSize};\n int offset_w = imod(w, ${blockSize});\n int offset_d = (offset_h * ${blockSize} + offset_w) *\n ${this.getOutputDepthSize()};\n int in_d = d + offset_d;\n\n float result = ${this.getInputSamplingString()};\n setOutput(result);\n }\n `;\n }\n\n private getHeightCoordString(): string {\n if (this.dataFormat === 'NHWC') {\n return `coords[1]`;\n } else {\n return `coords[2]`;\n }\n }\n\n private getWidthCoordString(): string {\n if (this.dataFormat === 'NHWC') {\n return `coords[2]`;\n } else {\n return `coords[3]`;\n }\n }\n\n private getDepthCoordString(): string {\n if (this.dataFormat === 'NHWC') {\n return `coords[3]`;\n } else {\n return `coords[1]`;\n }\n }\n\n private getOutputDepthSize(): number {\n if (this.dataFormat === 'NHWC') {\n return this.outputShape[3];\n } else {\n return this.outputShape[1];\n }\n }\n\n private getInputSamplingString(): string {\n if (this.dataFormat === 'NHWC') {\n return `getX(b, in_h, in_w, in_d)`;\n } else {\n return `getX(b, in_d, in_h, in_w)`;\n }\n }\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class DiagProgram implements GPGPUProgram {\n variableNames = ['X'];\n outputShape: number[];\n userCode: string;\n\n constructor(size: number) {\n this.outputShape = [size, size];\n this.userCode = `\n void main() {\n ivec2 coords = getOutputCoords();\n float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;\n setOutput(val);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getGlslDifferences} from './glsl_version';\nimport {GPGPUProgram} from './gpgpu_math';\nimport {ENCODE_FLOAT_SNIPPET} from './shader_compiler_util';\nimport {TextureUsage} from './tex_util';\n\nexport class EncodeFloatProgram implements GPGPUProgram {\n variableNames = ['A'];\n userCode: string;\n outputShape: number[];\n outTexUsage = TextureUsage.DOWNLOAD;\n\n constructor(outputShape: number[]) {\n const glsl = getGlslDifferences();\n this.outputShape = outputShape;\n this.userCode = `\n ${ENCODE_FLOAT_SNIPPET}\n\n void main() {\n float x = getAAtOutCoords();\n ${glsl.output} = encode_float(x);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getGlslDifferences} from './glsl_version';\nimport {GPGPUProgram} from './gpgpu_math';\nimport {ENCODE_FLOAT_SNIPPET} from './shader_compiler_util';\nimport {TextureUsage} from './tex_util';\n\nexport class EncodeFloatPackedProgram implements GPGPUProgram {\n variableNames = ['A'];\n userCode: string;\n outputShape: number[];\n packedInputs = true;\n packedOutput = false;\n outTexUsage = TextureUsage.DOWNLOAD;\n\n constructor(outputShape: [number, number, number]) {\n const glsl = getGlslDifferences();\n this.outputShape = outputShape;\n this.userCode = `\n ${ENCODE_FLOAT_SNIPPET}\n\n void main() {\n ivec3 coords = getOutputCoords();\n float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));\n ${glsl.output} = encode_float(x);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getGlslDifferences} from './glsl_version';\nimport {GPGPUProgram} from './gpgpu_math';\nimport * as shader_util from './shader_compiler_util';\n\nexport class EncodeMatrixProgram implements GPGPUProgram {\n variableNames = ['A'];\n userCode: string;\n outputShape: number[];\n\n constructor(\n outputShape: [number, number, number], texShape: [number, number],\n inputIsUnsignedByte = false) {\n const glsl = getGlslDifferences();\n const [height, width] = texShape;\n this.outputShape = outputShape;\n\n let output = `result`;\n if (inputIsUnsignedByte) {\n output = `floor(result * 255. + 0.5)`;\n }\n\n this.userCode = `\n ${shader_util.getFlatIndexFrom3D(outputShape)}\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n int flatIndex = getFlatIndex(coords);\n int offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n \n int r = flatIndex / ${width};\n int c = imod(flatIndex, ${width});\n vec2 uv = (vec2(c, r) + halfCR) / vec2(${width}.0, ${height}.0);\n vec4 values = ${glsl.texture2D}(A, uv);\n\n float result;\n\n if(offset == 0) {\n result = values[0];\n } else if(offset == 1) {\n result = values[1];\n } else if(offset == 2) {\n result = values[2];\n } else {\n result = values[3];\n }\n\n ${glsl.output} = vec4(${output}, 0., 0., 0.);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getGlslDifferences} from './glsl_version';\nimport {GPGPUProgram} from './gpgpu_math';\nimport * as shader_util from './shader_compiler_util';\n\n/*\nThis is how the shader encodes a tensor with shape = [2, 3, 5]\n(indices are [batch, row, col]).\n\n000|001 002|003 004|xxx 020|021 022|023 024|xxx\n------- ------- ------- ------- ------- -------\n010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx\n\n100|101 102|103 104|xxx 120|121 122|123 124|xxx\n------- ------- ------- ------- ------- -------\n110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx\n\nSingle texels contain only values from the same batch, and from adjacent rows\nand columns.\n */\n\nexport class EncodeMatrixPackedProgram implements GPGPUProgram {\n variableNames = ['A'];\n userCode: string;\n outputShape: number[];\n packedInputs = false;\n packedOutput = true;\n\n constructor(\n outputShape: [number, number, number], texShape: [number, number],\n inputIsUnsignedByte = false) {\n const glsl = getGlslDifferences();\n const [height, width] = texShape;\n this.outputShape = outputShape;\n\n let mainLoop = '';\n let output = 'result';\n if (inputIsUnsignedByte) {\n output = 'floor(result * 255. + 0.5)';\n }\n\n for (let row = 0; row <= 1; row++) {\n for (let col = 0; col <= 1; col++) {\n const channel = row * 2 + col;\n\n mainLoop += `\n localCoords = coords;\n if(localCoords[2] + ${col} < ${outputShape[2]}) {\n localCoords[2] += ${col};\n if(localCoords[1] + ${row} < ${outputShape[1]}) {\n localCoords[1] += ${row};\n\n flatIndex = getFlatIndex(localCoords);\n offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n\n r = flatIndex / ${width};\n c = imod(flatIndex, ${width});\n uv = (vec2(c, r) + halfCR) / vec2(${width}.0, ${height}.0);\n values = ${glsl.texture2D}(A, uv);\n\n if(offset == 0) {\n result[${channel}] = values[0];\n } else if(offset == 1) {\n result[${channel}] = values[1];\n } else if(offset == 2) {\n result[${channel}] = values[2];\n } else {\n result[${channel}] = values[3];\n }\n }\n }\n `;\n }\n }\n\n this.userCode = `\n ${shader_util.getFlatIndexFrom3D(outputShape)}\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n vec4 result = vec4(0.);\n int flatIndex, r, c, offset;\n ivec3 localCoords;\n vec2 uv;\n vec4 values;\n\n ${mainLoop}\n\n ${glsl.output} = ${output};\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport const COMPLEX_FFT = {\n REAL: 'return real * expR - imag * expI;',\n IMAG: 'return real * expI + imag * expR;'\n};\n\nexport class FFTProgram implements GPGPUProgram {\n variableNames = ['real', 'imag'];\n outputShape: number[];\n userCode: string;\n\n constructor(op: string, inputShape: [number, number], inverse: boolean) {\n const innerDim = inputShape[1];\n this.outputShape = inputShape;\n\n const exponentMultiplierSnippet =\n inverse ? `2.0 * ${Math.PI}` : `-2.0 * ${Math.PI}`;\n const resultDenominator = inverse ? `${innerDim}.0` : '1.0';\n\n this.userCode = `\n const float exponentMultiplier = ${exponentMultiplierSnippet};\n\n float unaryOpComplex(float real, float expR, float imag, float expI) {\n ${op}\n }\n\n float mulMatDFT(int batch, int index) {\n float indexRatio = float(index) / float(${innerDim});\n float exponentMultiplierTimesIndexRatio =\n exponentMultiplier * indexRatio;\n\n float result = 0.0;\n\n for (int i = 0; i < ${innerDim}; i++) {\n // x = (-2|2 * PI / N) * index * i;\n float x = exponentMultiplierTimesIndexRatio * float(i);\n float expR = cos(x);\n float expI = sin(x);\n float real = getReal(batch, i);\n float imag = getImag(batch, i);\n\n result +=\n unaryOpComplex(real, expR, imag, expI) / ${resultDenominator};\n }\n\n return result;\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n setOutput(mulMatDFT(coords[0], coords[1]));\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUContext} from './gpgpu_context';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class FillProgram implements GPGPUProgram {\n variableNames: string[];\n outputShape: number[] = [];\n userCode: string;\n\n valueLoc: WebGLUniformLocation;\n\n constructor(shape: number[], value: number) {\n this.variableNames = ['x'];\n this.outputShape = shape;\n\n this.userCode = `\n uniform float value;\n void main() {\n // Input can be obtained from uniform value.\n setOutput(value);\n }\n `;\n }\n\n getCustomSetupFunc(value: number) {\n return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {\n if (this.valueLoc == null) {\n this.valueLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'value');\n }\n gpgpu.gl.uniform1f(this.valueLoc, value);\n };\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class GatherProgram implements GPGPUProgram {\n variableNames = ['A', 'indices'];\n outputShape: number[];\n userCode: string;\n rank: number;\n\n constructor(aShape: number[], indicesLength: number, axis: number) {\n const outputShape: number[] = aShape.slice();\n outputShape[axis] = indicesLength;\n this.outputShape = outputShape;\n this.rank = outputShape.length;\n const dtype = getCoordsDataType(this.rank);\n const sourceCoords = getSourceCoords(aShape, axis);\n\n this.userCode = `\n void main() {\n ${dtype} resRC = getOutputCoords();\n setOutput(getA(${sourceCoords}));\n }\n `;\n }\n}\n\nfunction getSourceCoords(aShape: number[], axis: number): string {\n const rank = aShape.length;\n if (rank > 4) {\n throw Error(`Gather for rank ${rank} is not yet supported`);\n }\n if (rank === 1) {\n return `int(getIndices(resRC))`;\n }\n\n const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];\n\n const sourceCoords = [];\n for (let i = 0; i < aShape.length; i++) {\n if (i === axis) {\n sourceCoords.push(`int(getIndices(${currentCoords[i]}))`);\n } else {\n sourceCoords.push(`${currentCoords[i]}`);\n }\n }\n return sourceCoords.join();\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class GatherNDProgram implements GPGPUProgram {\n variableNames = ['x', 'indices'];\n outputShape: number[];\n userCode: string;\n constructor(\n private sliceDim: number, private strides: number[], shape: number[]) {\n this.outputShape = shape;\n const stridesType = getCoordsDataType(strides.length);\n const dtype = getCoordsDataType(shape.length);\n const strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides';\n this.userCode = `\n ${stridesType} strides = ${stridesType}(${this.strides});\n void main() {\n ${dtype} coords = getOutputCoords();\n int flattenIndex = 0;\n for (int j = 0; j < ${this.sliceDim}; j++) {\n int index = round(getIndices(coords[0], j));\n flattenIndex += index * ${strideString};\n }\n setOutput(getX(flattenIndex, coords[1]));\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {PixelData, TypedArray} from '../../types';\n\nimport {getGlslDifferences} from './glsl_version';\nimport * as tex_util from './tex_util';\nimport {TextureConfig} from './tex_util';\nimport * as webgl_util from './webgl_util';\n\nexport function createVertexShader(\n gl: WebGLRenderingContext, debug: boolean): WebGLShader {\n const glsl = getGlslDifferences();\n const vertexShaderSource = `${glsl.version}\n precision highp float;\n ${glsl.attribute} vec3 clipSpacePos;\n ${glsl.attribute} vec2 uv;\n ${glsl.varyingVs} vec2 resultUV;\n\n void main() {\n gl_Position = vec4(clipSpacePos, 1);\n resultUV = uv;\n }`;\n return webgl_util.createVertexShader(gl, debug, vertexShaderSource);\n}\n\nexport function createVertexBuffer(\n gl: WebGLRenderingContext, debug: boolean): WebGLBuffer {\n // [x y z u v] * [upper-left, lower-left, upper-right, lower-right]\n const vertexArray = new Float32Array(\n [-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]);\n return webgl_util.createStaticVertexBuffer(gl, debug, vertexArray);\n}\n\nexport function createIndexBuffer(\n gl: WebGLRenderingContext, debug: boolean): WebGLBuffer {\n // OpenGL (and WebGL) have \"CCW == front\" winding\n const triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]);\n return webgl_util.createStaticIndexBuffer(gl, debug, triangleVertexIndices);\n}\n\nfunction createAndConfigureTexture(\n gl: WebGLRenderingContext, debug: boolean, width: number, height: number,\n internalFormat: number, textureFormat: number,\n textureType: number): WebGLTexture {\n webgl_util.validateTextureSize(width, height);\n const texture = webgl_util.createTexture(gl, debug);\n\n const tex2d = gl.TEXTURE_2D;\n webgl_util.callAndCheck(gl, debug, () => gl.bindTexture(tex2d, texture));\n webgl_util.callAndCheck(\n gl, debug,\n () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE));\n webgl_util.callAndCheck(\n gl, debug,\n () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE));\n webgl_util.callAndCheck(\n gl, debug,\n () => gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST));\n webgl_util.callAndCheck(\n gl, debug,\n () => gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST));\n webgl_util.callAndCheck(\n gl, debug,\n () => gl.texImage2D(\n tex2d, 0, internalFormat, width, height, 0, textureFormat,\n textureType, null));\n webgl_util.callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, null));\n return texture;\n}\n\nexport function createFloat32MatrixTexture(\n gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number,\n textureConfig: TextureConfig): WebGLTexture {\n const [width, height] =\n tex_util.getUnpackedMatrixTextureShapeWidthHeight(rows, columns);\n return createAndConfigureTexture(\n gl, debug, width, height, textureConfig.internalFormatFloat,\n textureConfig.textureFormatFloat, gl.FLOAT);\n}\n\nexport function createFloat16MatrixTexture(\n gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number,\n textureConfig: TextureConfig): WebGLTexture {\n const [width, height] =\n tex_util.getUnpackedMatrixTextureShapeWidthHeight(rows, columns);\n return createAndConfigureTexture(\n gl, debug, width, height, textureConfig.internalFormatHalfFloat,\n textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat);\n}\n\nexport function createUnsignedBytesMatrixTexture(\n gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number,\n textureConfig: TextureConfig): WebGLTexture {\n const [width, height] =\n tex_util.getUnpackedMatrixTextureShapeWidthHeight(rows, columns);\n return createAndConfigureTexture(\n gl, debug, width, height, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE);\n}\n\nexport function createPackedMatrixTexture(\n gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number,\n textureConfig: TextureConfig): WebGLTexture {\n const [width, height] =\n tex_util.getPackedMatrixTextureShapeWidthHeight(rows, columns);\n return createAndConfigureTexture(\n gl, debug, width, height, textureConfig.internalFormatPackedFloat,\n gl.RGBA, gl.FLOAT);\n}\n\nexport function createFloat16PackedMatrixTexture(\n gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number,\n textureConfig: TextureConfig): WebGLTexture {\n const [width, height] =\n tex_util.getPackedMatrixTextureShapeWidthHeight(rows, columns);\n return createAndConfigureTexture(\n gl, debug, width, height, textureConfig.internalFormatPackedHalfFloat,\n gl.RGBA, textureConfig.textureTypeHalfFloat);\n}\n\nexport function bindVertexProgramAttributeStreams(\n gl: WebGLRenderingContext, debug: boolean, program: WebGLProgram,\n vertexBuffer: WebGLBuffer): boolean {\n const posOffset = 0; // x is the first buffer element\n const uvOffset = 3 * 4; // uv comes after [x y z]\n const stride = (3 * 4) + (2 * 4); // xyz + uv, each entry is 4-byte float.\n webgl_util.callAndCheck(\n gl, debug, () => gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer));\n const success = webgl_util.bindVertexBufferToProgramAttribute(\n gl, debug, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset);\n return success &&\n webgl_util.bindVertexBufferToProgramAttribute(\n gl, debug, program, 'uv', vertexBuffer, 2, stride, uvOffset);\n}\n\nexport function uploadDenseMatrixToTexture(\n gl: WebGLRenderingContext, debug: boolean, texture: WebGLTexture,\n width: number, height: number, data: TypedArray,\n textureConfig: TextureConfig) {\n webgl_util.callAndCheck(\n gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, texture));\n\n let dataForUpload: TypedArray, texelDataType: number, internalFormat: number;\n if (data instanceof Uint8Array) {\n dataForUpload = new Uint8Array(width * height * 4);\n texelDataType = gl.UNSIGNED_BYTE;\n internalFormat = gl.RGBA;\n } else {\n dataForUpload = new Float32Array(width * height * 4);\n texelDataType = gl.FLOAT;\n internalFormat = textureConfig.internalFormatPackedFloat;\n }\n\n dataForUpload.set(data);\n\n webgl_util.callAndCheck(\n gl, debug,\n () => gl.texImage2D(\n gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA,\n texelDataType, dataForUpload));\n\n webgl_util.callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, null));\n}\n\nexport function uploadPixelDataToTexture(\n gl: WebGLRenderingContext, debug: boolean, texture: WebGLTexture,\n pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|\n HTMLVideoElement) {\n webgl_util.callAndCheck(\n gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, texture));\n if ((pixels as PixelData).data instanceof Uint8Array) {\n webgl_util.callAndCheck(\n gl, debug,\n () => gl.texImage2D(\n gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA,\n gl.UNSIGNED_BYTE, (pixels as PixelData).data));\n } else {\n webgl_util.callAndCheck(\n gl, debug,\n () => gl.texImage2D(\n gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE,\n pixels as ImageData | HTMLImageElement | HTMLCanvasElement |\n HTMLVideoElement));\n }\n\n webgl_util.callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, null));\n}\n\nexport function createBufferFromOutputTexture(\n gl2: WebGL2RenderingContext, debug: boolean, rows: number, columns: number,\n textureConfig: TextureConfig): WebGLBuffer {\n // Create and bind the buffer.\n const buffer = gl2.createBuffer();\n webgl_util.callAndCheck(\n gl2, debug, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer));\n\n // Initialize the buffer to the size of the texture in bytes.\n const bytesPerFloat = 4;\n const valuesPerTexel = 4;\n const bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns;\n\n webgl_util.callAndCheck(\n gl2, debug,\n () => gl2.bufferData(\n gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ));\n\n // Enqueue a command on the GPU command queue to copy of texture into the\n // buffer.\n webgl_util.callAndCheck(\n gl2, debug,\n () => gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0));\n\n webgl_util.callAndCheck(\n gl2, debug, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null));\n\n return buffer;\n}\n\nexport function downloadFloat32MatrixFromBuffer(\n gl: WebGLRenderingContext, buffer: WebGLBuffer,\n size: number): Float32Array {\n const gl2 = gl as WebGL2RenderingContext;\n\n const downloadTarget = new Float32Array(size);\n\n gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);\n gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);\n gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);\n\n return downloadTarget;\n}\n\nexport function downloadByteEncodedFloatMatrixFromOutputTexture(\n gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number,\n textureConfig: TextureConfig) {\n const [w, h] =\n tex_util.getUnpackedMatrixTextureShapeWidthHeight(rows, columns);\n\n const numChannels = 4;\n const downloadTarget = new Uint8Array(\n tex_util.getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels));\n\n webgl_util.callAndCheck(\n gl, debug,\n () => gl.readPixels(\n 0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE,\n downloadTarget));\n\n // By wrapping the buffer in a Float32Array, we use native browser IEEE 754\n // decoding of the 4 bytes that back each 32 bit float.\n return new Float32Array(downloadTarget.buffer);\n}\n\nexport function downloadPackedMatrixFromBuffer(\n gl: WebGLRenderingContext, buffer: WebGLBuffer, batch: number, rows: number,\n cols: number, physicalRows: number, physicalCols: number,\n textureConfig: TextureConfig): Float32Array {\n const gl2 = gl as WebGL2RenderingContext;\n\n const downloadTarget =\n new Float32Array(tex_util.getPackedRGBAArraySizeFromMatrixShape(\n physicalRows, physicalCols));\n\n gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);\n gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);\n gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);\n\n return downloadTarget;\n}\n\nexport function downloadMatrixFromPackedOutputTexture(\n gl: WebGLRenderingContext, debug: boolean, physicalRows: number,\n physicalCols: number): Float32Array {\n const packedRGBA = new Float32Array(physicalRows * physicalCols * 4);\n webgl_util.callAndCheck(\n gl, debug,\n () => gl.readPixels(\n 0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA));\n\n return packedRGBA;\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../../environment';\n\nimport {PixelData, TypedArray} from '../../types';\nimport * as util from '../../util';\n\nimport {getWebGLContext, setWebGLContext} from './canvas_util';\nimport * as gpgpu_util from './gpgpu_util';\nimport * as tex_util from './tex_util';\nimport {TextureConfig} from './tex_util';\nimport {WebGL1DisjointQueryTimerExtension, WebGL2DisjointQueryTimerExtension} from './webgl_types';\nimport * as webgl_util from './webgl_util';\n\nexport interface FenceContext {\n query: WebGLQuery|WebGLSync;\n isFencePassed(): boolean;\n}\n\nexport class GPGPUContext {\n gl: WebGLRenderingContext;\n textureFloatExtension: {};\n textureHalfFloatExtension: {};\n colorBufferFloatExtension: {};\n colorBufferHalfFloatExtension: {};\n disjointQueryTimerExtension: WebGL2DisjointQueryTimerExtension|\n WebGL1DisjointQueryTimerExtension;\n vertexBuffer: WebGLBuffer;\n indexBuffer: WebGLBuffer;\n framebuffer: WebGLFramebuffer;\n outputTexture: WebGLTexture|null = null;\n program: WebGLProgram|null = null;\n private disposed = false;\n private disjoint: boolean;\n private textureConfig: TextureConfig;\n\n constructor(gl?: WebGLRenderingContext) {\n const glVersion = env().getNumber('WEBGL_VERSION');\n if (gl != null) {\n this.gl = gl;\n setWebGLContext(glVersion, gl);\n } else {\n this.gl = getWebGLContext(glVersion);\n }\n // WebGL 2.0 enables texture floats without an extension.\n let COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';\n const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';\n if (env().getNumber('WEBGL_VERSION') === 1) {\n const TEXTURE_FLOAT = 'OES_texture_float';\n const TEXTURE_HALF_FLOAT = 'OES_texture_half_float';\n\n this.textureFloatExtension =\n webgl_util.getExtensionOrThrow(this.gl, this.debug, TEXTURE_FLOAT);\n if (webgl_util.hasExtension(this.gl, TEXTURE_HALF_FLOAT)) {\n this.textureHalfFloatExtension = webgl_util.getExtensionOrThrow(\n this.gl, this.debug, TEXTURE_HALF_FLOAT);\n } else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {\n throw new Error(\n 'GL context does not support half float textures, yet the ' +\n 'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');\n }\n\n this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);\n if (webgl_util.hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {\n this.colorBufferHalfFloatExtension = webgl_util.getExtensionOrThrow(\n this.gl, this.debug, COLOR_BUFFER_HALF_FLOAT);\n } else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {\n throw new Error(\n 'GL context does not support color renderable half floats, yet ' +\n 'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');\n }\n } else {\n COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float';\n if (webgl_util.hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {\n this.colorBufferFloatExtension =\n this.gl.getExtension(COLOR_BUFFER_FLOAT);\n } else if (webgl_util.hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {\n this.colorBufferHalfFloatExtension =\n this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);\n } else {\n throw new Error('GL context does not support color renderable floats');\n }\n }\n\n this.vertexBuffer = gpgpu_util.createVertexBuffer(this.gl, this.debug);\n this.indexBuffer = gpgpu_util.createIndexBuffer(this.gl, this.debug);\n this.framebuffer = webgl_util.createFramebuffer(this.gl, this.debug);\n\n this.textureConfig =\n tex_util.getTextureConfig(this.gl, this.textureHalfFloatExtension);\n }\n\n private get debug(): boolean {\n return env().getBool('DEBUG');\n }\n\n public dispose() {\n if (this.disposed) {\n return;\n }\n if (this.program != null) {\n console.warn(\n 'Disposing a GPGPUContext that still has a bound WebGLProgram.' +\n ' This is probably a resource leak, delete the program with ' +\n 'GPGPUContext.deleteProgram before disposing.');\n }\n if (this.outputTexture != null) {\n console.warn(\n 'Disposing a GPGPUContext that still has a bound output matrix ' +\n 'texture. This is probably a resource leak, delete the output ' +\n 'matrix texture with GPGPUContext.deleteMatrixTexture before ' +\n 'disposing.');\n }\n const gl = this.gl;\n webgl_util.callAndCheck(gl, this.debug, () => gl.finish());\n webgl_util.callAndCheck(\n gl, this.debug, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null));\n webgl_util.callAndCheck(\n gl, this.debug, () => gl.deleteFramebuffer(this.framebuffer));\n webgl_util.callAndCheck(\n gl, this.debug, () => gl.bindBuffer(gl.ARRAY_BUFFER, null));\n webgl_util.callAndCheck(\n gl, this.debug, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null));\n webgl_util.callAndCheck(\n gl, this.debug, () => gl.deleteBuffer(this.indexBuffer));\n this.disposed = true;\n }\n\n public createFloat32MatrixTexture(rows: number, columns: number):\n WebGLTexture {\n this.throwIfDisposed();\n return gpgpu_util.createFloat32MatrixTexture(\n this.gl, this.debug, rows, columns, this.textureConfig);\n }\n\n public createFloat16MatrixTexture(rows: number, columns: number):\n WebGLTexture {\n this.throwIfDisposed();\n return gpgpu_util.createFloat16MatrixTexture(\n this.gl, this.debug, rows, columns, this.textureConfig);\n }\n\n public createUnsignedBytesMatrixTexture(rows: number, columns: number):\n WebGLTexture {\n this.throwIfDisposed();\n return gpgpu_util.createUnsignedBytesMatrixTexture(\n this.gl, this.debug, rows, columns, this.textureConfig);\n }\n\n public uploadPixelDataToTexture(\n texture: WebGLTexture,\n pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement) {\n this.throwIfDisposed();\n gpgpu_util.uploadPixelDataToTexture(this.gl, this.debug, texture, pixels);\n }\n\n public uploadDenseMatrixToTexture(\n texture: WebGLTexture, width: number, height: number, data: TypedArray) {\n this.throwIfDisposed();\n gpgpu_util.uploadDenseMatrixToTexture(\n this.gl, this.debug, texture, width, height, data, this.textureConfig);\n }\n\n public createFloat16PackedMatrixTexture(rows: number, columns: number):\n WebGLTexture {\n this.throwIfDisposed();\n return gpgpu_util.createFloat16PackedMatrixTexture(\n this.gl, this.debug, rows, columns, this.textureConfig);\n }\n\n public createPackedMatrixTexture(rows: number, columns: number):\n WebGLTexture {\n this.throwIfDisposed();\n return gpgpu_util.createPackedMatrixTexture(\n this.gl, this.debug, rows, columns, this.textureConfig);\n }\n\n public deleteMatrixTexture(texture: WebGLTexture) {\n this.throwIfDisposed();\n if (this.outputTexture === texture) {\n webgl_util.unbindColorTextureFromFramebuffer(\n this.gl, this.debug, this.framebuffer);\n this.outputTexture = null;\n }\n webgl_util.callAndCheck(\n this.gl, this.debug, () => this.gl.deleteTexture(texture));\n }\n\n public downloadByteEncodedFloatMatrixFromOutputTexture(\n texture: WebGLTexture, rows: number, columns: number): Float32Array {\n return this.downloadMatrixDriver(\n texture,\n () => gpgpu_util.downloadByteEncodedFloatMatrixFromOutputTexture(\n this.gl, this.debug, rows, columns, this.textureConfig));\n }\n\n public downloadPackedMatrixFromBuffer(\n buffer: WebGLBuffer, batch: number, rows: number, columns: number,\n physicalRows: number, physicalCols: number): Float32Array {\n return gpgpu_util.downloadPackedMatrixFromBuffer(\n this.gl, buffer, batch, rows, columns, physicalRows, physicalCols,\n this.textureConfig);\n }\n\n public downloadFloat32MatrixFromBuffer(buffer: WebGLBuffer, size: number):\n Float32Array {\n return gpgpu_util.downloadFloat32MatrixFromBuffer(this.gl, buffer, size);\n }\n\n public createBufferFromTexture(\n texture: WebGLTexture, rows: number, columns: number): WebGLBuffer {\n this.bindTextureToFrameBuffer(texture);\n const result = gpgpu_util.createBufferFromOutputTexture(\n this.gl as WebGL2RenderingContext, this.debug, rows, columns,\n this.textureConfig);\n this.unbindTextureToFrameBuffer();\n return result;\n }\n\n public createAndWaitForFence(): Promise {\n const fenceContext = this.createFence(this.gl);\n return this.pollFence(fenceContext);\n }\n\n private createFence(gl: WebGLRenderingContext): FenceContext {\n let query: WebGLQuery|WebGLSync;\n let isFencePassed: () => boolean;\n\n if (env().getBool('WEBGL_FENCE_API_ENABLED')) {\n const gl2 = gl as WebGL2RenderingContext;\n\n const sync = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0);\n gl.flush();\n\n isFencePassed = () => {\n const status = gl2.clientWaitSync(sync, 0, 0);\n return status === gl2.ALREADY_SIGNALED ||\n status === gl2.CONDITION_SATISFIED;\n };\n\n query = sync;\n } else if (\n env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {\n query = this.beginQuery();\n this.endQuery();\n isFencePassed = () => this.isQueryAvailable(\n query,\n env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));\n } else {\n // If we have no way to fence, return true immediately. This will fire in\n // WebGL 1.0 when there is no disjoint query timer. In this case, because\n // the fence passes immediately, we'll immediately ask for a download of\n // the texture, which will cause the UI thread to hang.\n isFencePassed = () => true;\n }\n\n return {query, isFencePassed};\n }\n\n public downloadMatrixFromPackedTexture(\n texture: WebGLTexture, physicalRows: number,\n physicalCols: number): Float32Array {\n return this.downloadMatrixDriver(\n texture,\n () => gpgpu_util.downloadMatrixFromPackedOutputTexture(\n this.gl, this.debug, physicalRows, physicalCols));\n }\n\n private vertexAttrsAreBound = false;\n\n public createProgram(fragmentShaderSource: string): WebGLProgram {\n this.throwIfDisposed();\n const gl = this.gl;\n const fragmentShader: WebGLShader =\n webgl_util.createFragmentShader(gl, this.debug, fragmentShaderSource);\n const vertexShader: WebGLShader =\n gpgpu_util.createVertexShader(gl, this.debug);\n const program: WebGLProgram = webgl_util.createProgram(\n gl,\n this.debug,\n );\n webgl_util.callAndCheck(\n gl, this.debug, () => gl.attachShader(program, vertexShader));\n webgl_util.callAndCheck(\n gl, this.debug, () => gl.attachShader(program, fragmentShader));\n webgl_util.linkProgram(gl, this.debug, program);\n if (this.debug) {\n webgl_util.validateProgram(gl, this.debug, program);\n }\n if (!this.vertexAttrsAreBound) {\n this.setProgram(program);\n this.vertexAttrsAreBound = gpgpu_util.bindVertexProgramAttributeStreams(\n gl, this.debug, this.program, this.vertexBuffer);\n }\n return program;\n }\n\n public deleteProgram(program: WebGLProgram) {\n this.throwIfDisposed();\n if (program === this.program) {\n this.program = null;\n }\n if (program != null) {\n webgl_util.callAndCheck(\n this.gl, this.debug, () => this.gl.deleteProgram(program));\n }\n }\n\n public setProgram(program: WebGLProgram|null) {\n this.throwIfDisposed();\n this.program = program;\n if ((this.program != null) && this.debug) {\n webgl_util.validateProgram(this.gl, this.debug, this.program);\n }\n webgl_util.callAndCheck(\n this.gl, this.debug, () => this.gl.useProgram(program));\n }\n\n public getUniformLocation(\n program: WebGLProgram, uniformName: string,\n shouldThrow = true): WebGLUniformLocation {\n this.throwIfDisposed();\n if (shouldThrow) {\n return webgl_util.getProgramUniformLocationOrThrow(\n this.gl, this.debug, program, uniformName);\n } else {\n return webgl_util.getProgramUniformLocation(\n this.gl, program, uniformName);\n }\n }\n\n public getAttributeLocation(program: WebGLProgram, attribute: string):\n number {\n this.throwIfDisposed();\n return webgl_util.callAndCheck(\n this.gl, this.debug,\n () => this.gl.getAttribLocation(program, attribute));\n }\n\n public getUniformLocationNoThrow(program: WebGLProgram, uniformName: string):\n WebGLUniformLocation {\n this.throwIfDisposed();\n return this.gl.getUniformLocation(program, uniformName);\n }\n\n public setInputMatrixTexture(\n inputMatrixTexture: WebGLTexture, uniformLocation: WebGLUniformLocation,\n textureUnit: number) {\n this.throwIfDisposed();\n this.throwIfNoProgram();\n webgl_util.bindTextureToProgramUniformSampler(\n this.gl, this.debug, this.program, inputMatrixTexture, uniformLocation,\n textureUnit);\n }\n\n public setOutputMatrixTexture(\n outputMatrixTexture: WebGLTexture, rows: number, columns: number) {\n this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows);\n }\n\n public setOutputPackedMatrixTexture(\n outputPackedMatrixTexture: WebGLTexture, rows: number, columns: number) {\n this.throwIfDisposed();\n const [width, height] =\n tex_util.getPackedMatrixTextureShapeWidthHeight(rows, columns);\n this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height);\n }\n\n public setOutputMatrixWriteRegion(\n startRow: number, numRows: number, startColumn: number,\n numColumns: number) {\n this.setOutputMatrixWriteRegionDriver(\n startColumn, startRow, numColumns, numRows);\n }\n\n public setOutputPackedMatrixWriteRegion(\n startRow: number, numRows: number, startColumn: number,\n numColumns: number) {\n throw new Error('setOutputPackedMatrixWriteRegion not implemented.');\n }\n\n public debugValidate() {\n if (this.program != null) {\n webgl_util.validateProgram(this.gl, this.debug, this.program);\n }\n webgl_util.validateFramebuffer(this.gl);\n }\n\n public executeProgram() {\n this.throwIfDisposed();\n this.throwIfNoProgram();\n const gl = this.gl;\n if (this.debug) {\n this.debugValidate();\n }\n webgl_util.callAndCheck(\n gl, this.debug,\n () => gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0));\n }\n\n public blockUntilAllProgramsCompleted() {\n this.throwIfDisposed();\n webgl_util.callAndCheck(this.gl, this.debug, () => this.gl.finish());\n }\n\n private getQueryTimerExtension(): WebGL1DisjointQueryTimerExtension\n |WebGL2DisjointQueryTimerExtension {\n if (this.disjointQueryTimerExtension == null) {\n this.disjointQueryTimerExtension =\n webgl_util.getExtensionOrThrow(\n this.gl, this.debug,\n env().getNumber(\n 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ?\n 'EXT_disjoint_timer_query_webgl2' :\n 'EXT_disjoint_timer_query') as\n WebGL1DisjointQueryTimerExtension |\n WebGL2DisjointQueryTimerExtension;\n }\n return this.disjointQueryTimerExtension;\n }\n\n private getQueryTimerExtensionWebGL2(): WebGL2DisjointQueryTimerExtension {\n return this.getQueryTimerExtension();\n }\n\n private getQueryTimerExtensionWebGL1(): WebGL1DisjointQueryTimerExtension {\n return this.getQueryTimerExtension() as WebGL1DisjointQueryTimerExtension;\n }\n\n beginQuery(): WebGLQuery {\n if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {\n const gl2 = this.gl as WebGL2RenderingContext;\n const ext = this.getQueryTimerExtensionWebGL2();\n\n const query = gl2.createQuery();\n gl2.beginQuery(ext.TIME_ELAPSED_EXT, query);\n return query;\n }\n const ext = this.getQueryTimerExtensionWebGL1();\n const query = ext.createQueryEXT() as WebGLQuery;\n ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query);\n return query;\n }\n\n endQuery() {\n if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {\n const gl2 = this.gl as WebGL2RenderingContext;\n const ext = this.getQueryTimerExtensionWebGL2();\n gl2.endQuery(ext.TIME_ELAPSED_EXT);\n return;\n }\n const ext = this.getQueryTimerExtensionWebGL1();\n ext.endQueryEXT(ext.TIME_ELAPSED_EXT);\n }\n\n public async waitForQueryAndGetTime(query: WebGLQuery): Promise {\n await util.repeatedTry(\n () => this.disposed || // while testing contexts are created / disposed\n // in rapid succession, so without this check we\n // may poll for the query timer indefinitely\n this.isQueryAvailable(\n query,\n env().getNumber(\n 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')));\n return this.getQueryTime(\n query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));\n }\n\n private getQueryTime(query: WebGLQuery, queryTimerVersion: number): number {\n if (queryTimerVersion === 0) {\n return null;\n }\n\n if (queryTimerVersion === 2) {\n const gl2 = this.gl as WebGL2RenderingContext;\n\n const timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT);\n // Return milliseconds.\n return timeElapsedNanos / 1000000;\n } else {\n const ext = this.getQueryTimerExtensionWebGL1();\n\n const timeElapsedNanos =\n ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT);\n // Return milliseconds.\n return timeElapsedNanos / 1000000;\n }\n }\n\n private isQueryAvailable(query: WebGLQuery, queryTimerVersion: number):\n boolean {\n if (queryTimerVersion === 0) {\n return true;\n }\n\n if (queryTimerVersion === 2) {\n const gl2 = this.gl as WebGL2RenderingContext;\n const ext = this.getQueryTimerExtensionWebGL2();\n\n const available =\n gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);\n if (this.disjoint == null) {\n this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);\n }\n\n return available && !this.disjoint;\n } else {\n const ext = this.getQueryTimerExtensionWebGL1();\n\n const available =\n ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT);\n if (this.disjoint == null) {\n this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);\n }\n\n return available && !this.disjoint;\n }\n }\n\n pollFence(fenceContext: FenceContext) {\n return new Promise(resolve => {\n this.addItemToPoll(() => fenceContext.isFencePassed(), () => resolve());\n });\n }\n\n private itemsToPoll: PollItem[] = [];\n\n pollItems(): void {\n // Find the last query that has finished.\n const index = linearSearchLastTrue(this.itemsToPoll.map(x => x.isDoneFn));\n for (let i = 0; i <= index; ++i) {\n const {resolveFn} = this.itemsToPoll[i];\n resolveFn();\n }\n this.itemsToPoll = this.itemsToPoll.slice(index + 1);\n }\n\n private addItemToPoll(isDoneFn: () => boolean, resolveFn: () => void) {\n this.itemsToPoll.push({isDoneFn, resolveFn});\n if (this.itemsToPoll.length > 1) {\n // We already have a running loop that polls.\n return;\n }\n // Start a new loop that polls.\n util.repeatedTry(() => {\n this.pollItems();\n // End the loop if no more items to poll.\n return this.itemsToPoll.length === 0;\n });\n }\n\n private bindTextureToFrameBuffer(texture: WebGLTexture) {\n this.throwIfDisposed();\n webgl_util.bindColorTextureToFramebuffer(\n this.gl, this.debug, texture, this.framebuffer);\n if (this.debug) {\n webgl_util.validateFramebuffer(this.gl);\n }\n }\n\n private unbindTextureToFrameBuffer() {\n if (this.outputTexture != null) {\n webgl_util.bindColorTextureToFramebuffer(\n this.gl, this.debug, this.outputTexture, this.framebuffer);\n if (this.debug) {\n webgl_util.validateFramebuffer(this.gl);\n }\n } else {\n webgl_util.unbindColorTextureFromFramebuffer(\n this.gl, this.debug, this.framebuffer);\n }\n }\n\n private downloadMatrixDriver(\n texture: WebGLTexture,\n downloadAndDecode: () => Float32Array): Float32Array {\n this.bindTextureToFrameBuffer(texture);\n const result = downloadAndDecode();\n this.unbindTextureToFrameBuffer();\n\n return result;\n }\n\n private setOutputMatrixTextureDriver(\n outputMatrixTextureMaybePacked: WebGLTexture, width: number,\n height: number) {\n this.throwIfDisposed();\n const gl = this.gl;\n webgl_util.bindColorTextureToFramebuffer(\n gl, this.debug, outputMatrixTextureMaybePacked, this.framebuffer);\n if (this.debug) {\n webgl_util.validateFramebuffer(gl);\n }\n this.outputTexture = outputMatrixTextureMaybePacked;\n webgl_util.callAndCheck(\n gl, this.debug, () => gl.viewport(0, 0, width, height));\n webgl_util.callAndCheck(\n gl, this.debug, () => gl.scissor(0, 0, width, height));\n }\n\n private setOutputMatrixWriteRegionDriver(\n x: number, y: number, width: number, height: number) {\n this.throwIfDisposed();\n webgl_util.callAndCheck(\n this.gl, this.debug, () => this.gl.scissor(x, y, width, height));\n }\n\n private throwIfDisposed() {\n if (this.disposed) {\n throw new Error('Attempted to use disposed GPGPUContext.');\n }\n }\n\n private throwIfNoProgram() {\n if (this.program == null) {\n throw new Error('No GPU program is currently set.');\n }\n }\n}\n\ntype PollItem = {\n isDoneFn: () => boolean,\n resolveFn: () => void\n};\n\n/**\n * Finds the index of the last true element using linear search.\n * Note: We can't do binary search because Chrome expects us to explicitly\n * test all fences before download:\n * https://github.com/tensorflow/tfjs/issues/1145\n */\nexport function linearSearchLastTrue(arr: Array<() => boolean>): number {\n let i = 0;\n for (; i < arr.length; ++i) {\n const isDone = arr[i]();\n if (!isDone) {\n break;\n }\n }\n return i - 1;\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../../environment';\nimport {Tensor} from '../../tensor';\nimport {TypedArray} from '../../types';\nimport * as util from '../../util';\n\nimport {GPGPUContext} from './gpgpu_context';\nimport * as shader_compiler from './shader_compiler';\nimport {InputInfo, ShapeInfo} from './shader_compiler';\nimport {PackingScheme, TextureData, TextureUsage} from './tex_util';\n\nexport interface GPGPUProgram {\n variableNames: string[];\n outputShape: number[];\n userCode: string;\n /** If true, this program expects packed input textures. Defaults to false. */\n packedInputs?: boolean;\n /** If true, this program produces a packed texture. Defaults to false. */\n packedOutput?: boolean;\n /**\n * Affects what type of texture we allocate for the output. Defaults to\n * `TextureUsage.RENDER`.\n */\n outTexUsage?: TextureUsage;\n /**\n * The type of scheme to use when packing texels for the output values.\n * See `PackingScheme` for details. Defaults to `PackingScheme.SHARED_BATCH`.\n */\n outPackingScheme?: PackingScheme;\n}\n\nexport interface GPGPUBinary {\n webGLProgram: WebGLProgram;\n program: GPGPUProgram;\n uniformLocations: {[name: string]: WebGLUniformLocation};\n source: string;\n inShapeInfos: ShapeInfo[];\n outShapeInfo: ShapeInfo;\n infLoc: WebGLUniformLocation;\n nanLoc: WebGLUniformLocation;\n}\n\nexport interface TensorData {\n shape: number[];\n texData: TextureData;\n isUniform: boolean;\n // Available when we decide to upload as uniform instead of texture.\n uniformValues?: TypedArray;\n}\n\nexport function compileProgram(\n gpgpu: GPGPUContext, program: GPGPUProgram, inputs: TensorData[],\n output: TensorData): GPGPUBinary {\n const userCode = program.userCode;\n const inputInfos: InputInfo[] = inputs.map((input, i) => {\n const shapeInfo: ShapeInfo = {\n logicalShape: input.shape,\n texShape: input.isUniform ? null : input.texData.texShape,\n isUniform: input.isUniform,\n isPacked: input.isUniform ? false : input.texData.isPacked,\n flatOffset: null\n };\n if (input.texData != null && input.texData.slice != null &&\n input.texData.slice.flatOffset > 0) {\n shapeInfo.flatOffset = input.texData.slice.flatOffset;\n }\n return {name: program.variableNames[i], shapeInfo};\n });\n const inShapeInfos = inputInfos.map(x => x.shapeInfo);\n const outShapeInfo: ShapeInfo = {\n logicalShape: output.shape,\n texShape: output.texData.texShape,\n isUniform: false,\n isPacked: output.texData.isPacked,\n flatOffset: null\n };\n const source = shader_compiler.makeShader(\n inputInfos, outShapeInfo, userCode, program.packedInputs);\n\n const webGLProgram = gpgpu.createProgram(source);\n\n // Add special uniforms (NAN, INFINITY)\n let infLoc: WebGLUniformLocation = null;\n const nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false);\n if (env().getNumber('WEBGL_VERSION') === 1) {\n infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false);\n }\n\n // Add user-defined uniforms\n const uniformLocations: {[name: string]: WebGLUniformLocation} = {};\n for (let i = 0; i < program.variableNames.length; i++) {\n const varName = program.variableNames[i];\n const shouldThrow = false;\n uniformLocations[varName] =\n gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow);\n uniformLocations[`offset${varName}`] =\n gpgpu.getUniformLocation(webGLProgram, `offset${varName}`, shouldThrow);\n }\n\n return {\n program,\n source,\n webGLProgram,\n uniformLocations,\n inShapeInfos,\n outShapeInfo,\n infLoc,\n nanLoc,\n };\n}\n\nfunction validateBinaryAndProgram(\n shapeInfos: ShapeInfo[], inputs: TensorData[]) {\n if (shapeInfos.length !== inputs.length) {\n throw Error(\n `Binary was compiled with ${shapeInfos.length} inputs, but ` +\n `was executed with ${inputs.length} inputs`);\n }\n\n shapeInfos.forEach((s, i) => {\n const shapeA = s.logicalShape;\n const input = inputs[i];\n const shapeB = input.shape;\n\n if (!util.arraysEqual(shapeA, shapeB)) {\n throw Error(\n `Binary was compiled with different shapes than ` +\n `the current args. Shapes ${shapeA} and ${shapeB} must match`);\n }\n // The input is uploaded as uniform.\n if (s.isUniform && input.isUniform) {\n return;\n }\n\n const texShapeA = s.texShape;\n const texShapeB = input.isUniform ? null : input.texData.texShape;\n if (!util.arraysEqual(texShapeA, texShapeB)) {\n throw Error(\n `Binary was compiled with different texture shapes than the` +\n ` current args. Shape ${texShapeA} and ${texShapeB} must match`);\n }\n });\n}\n\nexport function runProgram(\n gpgpu: GPGPUContext, binary: GPGPUBinary, inputs: TensorData[],\n output: TensorData,\n customSetup?: (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) =>\n void): void {\n validateBinaryAndProgram(binary.inShapeInfos, inputs);\n validateBinaryAndProgram([binary.outShapeInfo], [output]);\n\n const outTex = output.texData.texture;\n const outTexShape = output.texData.texShape;\n if (output.texData.isPacked) {\n gpgpu.setOutputPackedMatrixTexture(outTex, outTexShape[0], outTexShape[1]);\n } else {\n gpgpu.setOutputMatrixTexture(outTex, outTexShape[0], outTexShape[1]);\n }\n gpgpu.setProgram(binary.webGLProgram);\n\n // Set special uniforms (NAN, INFINITY)\n if (env().getNumber('WEBGL_VERSION') === 1) {\n if (binary.infLoc !== null) {\n gpgpu.gl.uniform1f(binary.infLoc, Infinity);\n }\n }\n if (binary.nanLoc !== null) {\n gpgpu.gl.uniform1f(binary.nanLoc, NaN);\n }\n\n // Set user-defined inputs\n inputs.forEach((input, i) => {\n const varName = binary.program.variableNames[i];\n const varLoc = binary.uniformLocations[varName];\n const varOffsetLoc = binary.uniformLocations[`offset${varName}`];\n\n if (varLoc == null) {\n // The compiler inferred that this variable is not used in this shader.\n return;\n }\n\n if (input.isUniform) {\n // Upload the values of the tensor as uniform.\n if (util.sizeFromShape(input.shape) < 2) {\n gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]);\n } else {\n let vals = input.uniformValues;\n if (!(vals instanceof Float32Array)) {\n vals = new Float32Array(vals);\n }\n gpgpu.gl.uniform1fv(varLoc, vals);\n }\n return;\n }\n\n // If the input was sliced, upload the flat offset index.\n if (input.texData.slice != null && varOffsetLoc != null) {\n gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset);\n }\n\n gpgpu.setInputMatrixTexture(input.texData.texture, varLoc, i);\n });\n\n if (customSetup != null) {\n customSetup(gpgpu, binary.webGLProgram);\n }\n gpgpu.executeProgram();\n}\n\nexport function makeShaderKey(\n program: GPGPUProgram, inputs: TensorData[], output: TensorData): string {\n let keyInputs = '';\n inputs.concat(output).forEach(x => {\n const hasOffset = x.texData != null && x.texData.slice != null &&\n x.texData.slice.flatOffset > 0;\n const texShape = x.isUniform ? 'uniform' : x.texData.texShape;\n keyInputs += `${x.shape}_${texShape}_${hasOffset}`;\n });\n const keyUserCode = program.userCode;\n let key = program.constructor.name;\n // Fast string concat. See https://jsperf.com/string-concatenation/14.\n key += '_' + keyInputs + '_' + keyUserCode;\n return key;\n}\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Conv2DInfo} from '../../ops/conv_util';\nimport {getGlslDifferences} from './glsl_version';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class Im2ColPackedProgram implements GPGPUProgram {\n variableNames = ['A'];\n packedInputs = true;\n packedOutput = true;\n outputShape: number[];\n userCode: string;\n\n constructor(\n outputShape: number[], inputShape: number[], convInfo: Conv2DInfo) {\n this.outputShape = outputShape;\n\n const {\n filterWidth,\n inChannels,\n strideWidth,\n strideHeight,\n padInfo,\n outWidth,\n dilationWidth,\n dilationHeight,\n dataFormat\n } = convInfo;\n const {left, top} = padInfo;\n const itemsPerBlockRow = inChannels * filterWidth;\n const glsl = getGlslDifferences();\n const isChannelsLast = dataFormat === 'channelsLast';\n const rowDim = isChannelsLast ? 0 : 1;\n const colDim = isChannelsLast ? 1 : 2;\n\n let unrolled = ``;\n\n for (let row = 0; row <= 1; row++) {\n for (let col = 0; col <= 1; col++) {\n unrolled += `\n blockIndex = rc.y + ${col};\n pos = rc.x + ${row};\n\n if(blockIndex < ${outputShape[1]} && pos < ${outputShape[0]}) {\n offsetY = int(blockIndex / (${outWidth})) * ${strideHeight} - ${\n top};\n d0 = offsetY + ${dilationHeight} * (pos / ${itemsPerBlockRow});\n\n if(d0 < ${inputShape[rowDim]} && d0 >= 0) {\n\n offsetX = int(mod(float(blockIndex), ${outWidth}.) * ${\n strideWidth}. - ${left}.);\n d1 = offsetX + ${dilationWidth} * (int(mod(float(pos), ${\n itemsPerBlockRow}.) / ${inChannels}.));\n\n if(d1 < ${inputShape[colDim]} && d1 >= 0) {\n\n ch = int(mod(float(pos), ${inChannels}.));\n\n if (${isChannelsLast}) {\n innerDims = vec2(d1, ch);\n result[${row * 2 + col}] = getChannel(\n getA(d0, int(innerDims.x),\n int(innerDims.y)), innerDims);\n } else {\n innerDims = vec2(d0, d1);\n result[${row * 2 + col}] = getChannel(\n getA(ch, int(innerDims.x),\n int(innerDims.y)), innerDims);\n }\n }\n }\n }\n `;\n }\n }\n\n this.userCode = `\n void main() {\n ivec2 rc = getOutputCoords();\n\n vec4 result = vec4(0);\n\n int blockIndex, pos, offsetY, d0, offsetX, d1, ch;\n vec2 innerDims;\n\n ${unrolled}\n\n ${glsl.output} = result;\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class LRNProgram implements GPGPUProgram {\n variableNames = ['x'];\n outputShape: number[] = [];\n userCode: string;\n\n constructor(\n xShape: number[], radius: number, bias: number, alpha: number,\n beta: number) {\n const rad = radius;\n const maxD = xShape[3] - 1;\n this.outputShape = xShape;\n\n // optimize pow(bias + alpha * sum, -beta)\n // src: https://github.com/tensorflow/tensorflow/..\n // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..\n // tensorflow/core/kernels/mkl_lrn_op.cc#L320\n let powOperator;\n const basis = `float(${bias}) + float(${alpha}) * sum`;\n if (beta === 0.5) {\n powOperator = `inversesqrt(${basis})`;\n } else if (beta === 1.0) {\n powOperator = `1.0/(${basis})`;\n } else {\n powOperator = `exp(log(${basis}) * float(-${beta}));`;\n }\n\n this.userCode = `\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n int d = coords[3];\n float x = getX(b, r, c, d);\n float sum = 0.0;\n for (int j = -${rad}; j <= ${rad}; j++) {\n int idx = d + j;\n if (idx >= 0 && idx <= ${maxD}) {\n float z = getX(b, r, c, idx);\n sum += z * z;\n }\n }\n float val = x * ${powOperator};\n setOutput(val);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class LRNGradProgram implements GPGPUProgram {\n variableNames = ['inputImage', 'outputImage', 'dy'];\n outputShape: number[] = [];\n userCode: string;\n depthRadius: number;\n bias: number;\n alpha: number;\n beta: number;\n depth: number;\n\n constructor(\n inputShape: number[], depthRadius: number, bias: number, alpha: number,\n beta: number) {\n this.outputShape = inputShape;\n this.depth = inputShape[3];\n this.depthRadius = depthRadius;\n this.bias = bias;\n this.alpha = alpha;\n this.beta = beta;\n this.userCode = `\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n\n float result = 0.0;\n for (int d = 0; d < ${this.depth}; ++d) {\n int depthBegin = int(max(0.0, float(d - ${depthRadius})));\n int depthEnd = int(min(float(${this.depth}),\n float(d + ${depthRadius} + 1)));\n\n const int MIN_DEPTH_BEGIN = 0;\n const int MAX_DEPTH_END = ${this.depth};\n\n float norm = 0.0;\n for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd) {\n norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);\n }\n else {\n break;\n }\n }\n\n norm = float(${alpha}) * norm + float(${bias});\n\n for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd){\n float dyi = -2.0 * float(${alpha})\n * float(${beta})\n * getInputImage(b ,r ,c, k) * getOutputImage(b, r, c, d)\n / norm;\n if (k == d) {\n dyi += pow(norm, -1.0 * ${beta});\n }\n if (k == coords[3]) {\n dyi *= getDy(b, r, c, d);\n result += dyi;\n }\n }\n else {\n break;\n }\n }\n }\n setOutput(result);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2019 Google LLC All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class LRNPackedProgram implements GPGPUProgram {\n variableNames = ['x'];\n outputShape: number[] = [];\n userCode: string;\n packedInputs = true;\n packedOutput = true;\n\n constructor(\n xShape: number[], radius: number, bias: number, alpha: number,\n beta: number) {\n const rad = radius;\n const maxD = xShape[3] - 1;\n this.outputShape = xShape;\n\n // optimize pow(bias + alpha * sum, -beta)\n // src: https://github.com/tensorflow/tensorflow/..\n // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..\n // tensorflow/core/kernels/mkl_lrn_op.cc#L320\n let powOperator;\n const basis = `float(${bias}) + float(${alpha}) * sum`;\n if (beta === 0.5) {\n powOperator = `inversesqrt(${basis})`;\n } else if (beta === 1.0) {\n powOperator = `1.0/(${basis})`;\n } else {\n powOperator = `exp(log(${basis}) * float(-${beta}));`;\n }\n\n this.userCode = `\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords.x;\n int r = coords.y;\n int c = coords.z;\n int d = coords.w;\n\n bool hasNextCol = d < ${this.outputShape[3]};\n bool hasNextRow = c < ${this.outputShape[2]};\n\n vec4 sum = vec4(0.);\n vec4 xFragAtOutputCoords = getX(b, r, c, d);\n\n vec4 xAtOutputCoords = vec4(\n getChannel(xFragAtOutputCoords, vec2(c, d)),\n hasNextCol ?\n getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,\n hasNextRow ?\n getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,\n (hasNextRow && hasNextCol) ?\n getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0\n );\n\n int firstChannel = d - ${rad};\n vec2 cache = vec2(0.);\n if(firstChannel >= 0){\n vec4 firstChannelFrag = getX(b, r, c, firstChannel);\n cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));\n if(hasNextRow){\n cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));\n }\n }\n\n ivec2 depth = ivec2(d, d + 1);\n for (int j = - ${rad}; j <= ${rad}; j++) {\n ivec2 idx = depth + j;\n bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));\n bvec2 belowUpperBound = lessThanEqual(idx, ivec2(${maxD}));\n\n bool depthInRange = aboveLowerBound.x && belowUpperBound.x;\n bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;\n\n if(depthInRange || depthPlusOneInRange){\n vec4 z = vec4(0.);\n vec4 xFragAtCurrentDepth;\n z.xz = cache.xy;\n if(depthPlusOneInRange && hasNextCol){\n xFragAtCurrentDepth = idx.y != d ?\n getX(b, r, c, idx.y) : xFragAtOutputCoords;\n z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));\n if(hasNextRow){\n z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));\n }\n }\n cache.xy = z.yw;\n sum += z * z;\n }\n }\n vec4 result = xAtOutputCoords * ${powOperator};\n setOutput(result);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class MaxPool2DBackpropProgram implements GPGPUProgram {\n variableNames = ['dy', 'maxPos'];\n outputShape: number[];\n userCode: string;\n\n constructor(convInfo: Conv2DInfo) {\n this.outputShape = convInfo.inShape;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationHeight = convInfo.dilationHeight;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n\n const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;\n const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;\n\n const lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1;\n this.userCode = `\n const ivec2 pads = ivec2(${padTop}, ${padLeft});\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < ${effectiveFilterHeight};\n wR += ${dilationHeight}) {\n float dyR = float(dyRCorner + wR) / ${strideHeight}.0;\n\n if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < ${effectiveFilterWidth}; wC++) {\n float dyC = float(dyCCorner + wC) / ${strideWidth}.0;\n\n if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n int maxPosValue = ${lastIndex} - int(getMaxPos(b, idyR, idyC, d));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue = wR * ${effectiveFilterWidth} + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n setOutput(dotProd);\n }\n `;\n }\n}\n\nexport class MaxPool3DBackpropProgram implements GPGPUProgram {\n variableNames = ['dy', 'maxPos'];\n outputShape: number[];\n userCode: string;\n\n constructor(convInfo: Conv3DInfo) {\n this.outputShape = convInfo.inShape;\n const strideDepth = convInfo.strideDepth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationDepth = convInfo.dilationDepth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterDepth = convInfo.effectiveFilterDepth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n\n const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;\n const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;\n const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;\n\n const lastIndex =\n effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1;\n this.userCode = `\n const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < ${effectiveFilterDepth};\n wD += ${dilationDepth}) {\n float dyD = float(dyDCorner + wD) / ${strideDepth}.0;\n\n if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < ${effectiveFilterHeight};\n wR += ${dilationHeight}) {\n float dyR = float(dyRCorner + wR) / ${strideHeight}.0;\n\n if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < ${effectiveFilterWidth};\n wC += ${dilationWidth}) {\n float dyC = float(dyCCorner + wC) / ${strideWidth}.0;\n\n if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n int maxPosValue = ${lastIndex} -\n int(getMaxPos(batch, idyD, idyR, idyC, ch));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue =\n wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +\n wR * ${effectiveFilterWidth} + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n }\n setOutput(dotProd);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class MatMulPackedProgram implements GPGPUProgram {\n variableNames = ['matrixA', 'matrixB'];\n packedInputs = true;\n packedOutput = true;\n outputShape: number[];\n userCode: string;\n\n constructor(\n aShape: [number, number, number], outputShape: [number, number, number],\n transposeA = false, transposeB = false, addBias = false,\n activation: string = null, hasPreluActivation = false) {\n this.outputShape = outputShape;\n\n const sharedDim = transposeA ? aShape[1] : aShape[2];\n const sharedDimensionPacked = Math.ceil(sharedDim / 2);\n\n const aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2';\n const bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z';\n const aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww'];\n const bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw'];\n\n let activationSnippet = '', applyActivationSnippet = '';\n if (activation) {\n if (hasPreluActivation) {\n activationSnippet = `vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n ${activation}\n }`;\n } else {\n activationSnippet = `vec4 activation(vec4 x) {\n ${activation}\n }`;\n }\n\n applyActivationSnippet = `result = activation(result);`;\n }\n\n const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';\n if (addBias) {\n this.variableNames.push('bias');\n }\n\n if (hasPreluActivation) {\n this.variableNames.push('preluActivationWeights');\n }\n\n this.userCode = `\n ${activationSnippet}\n\n const float sharedDimension = ${sharedDimensionPacked}.0;\n\n vec4 dot2x2ARowBCol(ivec3 rc) {\n vec4 result = vec4(0);\n for (int i = 0; i < ${sharedDimensionPacked}; i++) {\n vec4 a = getMatrixA(rc.x, ${aSample});\n vec4 b = getMatrixB(rc.x, ${bSample});\n\n // These swizzled products need to be separately added.\n // See: https://github.com/tensorflow/tfjs/issues/1735\n result += (${aSwizzle[0]} * ${bSwizzle[0]});\n result += (${aSwizzle[1]} * ${bSwizzle[1]});\n }\n return result;\n }\n\n void main() {\n ivec3 rc = getOutputCoords();\n vec4 result = dot2x2ARowBCol(rc);\n\n ${addBiasSnippet}\n\n ${applyActivationSnippet}\n\n setOutput(result);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUContext} from './gpgpu_context';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class MultinomialProgram implements GPGPUProgram {\n variableNames = ['probs'];\n outputShape: number[];\n userCode: string;\n\n // Caching uniform location for speed.\n seedLoc: WebGLUniformLocation;\n\n constructor(batchSize: number, numOutcomes: number, numSamples: number) {\n this.outputShape = [batchSize, numSamples];\n\n this.userCode = `\n uniform float seed;\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n\n float r = random(seed);\n float cdf = 0.0;\n\n for (int i = 0; i < ${numOutcomes - 1}; i++) {\n cdf += getProbs(batch, i);\n\n if (r < cdf) {\n setOutput(float(i));\n return;\n }\n }\n\n // If no other event happened, last event happened.\n setOutput(float(${numOutcomes - 1}));\n }\n `;\n }\n\n getCustomSetupFunc(seed: number) {\n return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {\n if (this.seedLoc == null) {\n this.seedLoc = gpgpu.getUniformLocation(webGLProgram, 'seed');\n }\n gpgpu.gl.uniform1f(this.seedLoc, seed);\n };\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class OneHotProgram implements GPGPUProgram {\n variableNames = ['indices'];\n outputShape: number[];\n userCode: string;\n\n // Caching uniform location for speed.\n seedLoc: WebGLUniformLocation;\n\n constructor(\n numIndices: number, depth: number, onValue: number, offValue: number) {\n this.outputShape = [numIndices, depth];\n\n this.userCode = `\n void main() {\n ivec2 coords = getOutputCoords();\n int index = round(getIndices(coords.x));\n setOutput(mix(float(${offValue}), float(${onValue}),\n float(index == coords.y)));\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getChannels} from '../packing_util';\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class PackProgram implements GPGPUProgram {\n variableNames = ['A'];\n outputShape: number[];\n userCode: string;\n packedInputs = false;\n packedOutput = true;\n\n constructor(\n outputShape:\n number[]) { // TODO(https://github.com/tensorflow/tfjs/issues/893):\n // Only input / output 3D tensors.\n this.outputShape = outputShape;\n const rank = outputShape.length;\n\n if (rank === 0) {\n this.userCode = `\n void main() {\n setOutput(vec4(getA(), 0., 0., 0.));\n }\n `;\n } else {\n const channels = getChannels('rc', rank);\n const dtype = getCoordsDataType(rank);\n const outOfBoundsCondition =\n getOutOfBoundsCondition(rank, outputShape, channels);\n const setup = getSetup(\n rank, outputShape[outputShape.length - 1],\n outputShape[outputShape.length - 2], channels);\n const output = getOutput(outputShape, channels);\n\n this.userCode = `\n void main() {\n ${dtype} rc = getOutputCoords();\n\n if(${outOfBoundsCondition}) {\n setOutput(vec4(0));\n } else {\n ${setup}\n\n setOutput(vec4(${output}));\n }\n }\n `;\n }\n }\n}\n\nfunction getSourceCoordsArr(rank: number, dims: string[]): string[] {\n const coords = [];\n\n for (let row = 0; row <= 1; row++) {\n for (let col = 0; col <= 1; col++) {\n let coord = `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`;\n\n for (let d = 2; d < rank; d++) {\n coord = `${dims[dims.length - 1 - d]},` + coord;\n }\n\n coords.push(coord);\n }\n }\n return coords;\n}\n\nfunction getOutOfBoundsCondition(\n rank: number, shape: number[], dims: string[]): string {\n if (rank === 1) {\n return `rc > ${shape[0]}`;\n }\n\n let cond = '';\n for (let i = rank - 2; i < rank; i++) {\n cond += `${dims[i]} >= ${shape[i]}`;\n if (i < rank - 1) {\n cond += '||';\n }\n }\n\n return cond;\n}\n\nfunction getSetup(\n rank: number, cols: number, rows: number, dims: string[]): string {\n if (rank === 1) {\n return '';\n }\n\n const innerDims = dims.slice(-2);\n\n return `\n int r = ${innerDims[0]};\n int c = ${innerDims[1]};\n int rp1 = r + 1;\n int cp1 = c + 1;\n\n bool cEdge = cp1 >= ${cols};\n bool rEdge = rp1 >= ${rows};\n `;\n}\n\nfunction getOutput(shape: number[], dims: string[]): string {\n const rank = shape.length;\n const sourceCoords = getSourceCoordsArr(rank, dims);\n if (rank === 1) {\n return `getA(rc),\n rc + 1 >= ${shape[0]} ? 0. : getA(rc + 1),\n 0, 0`;\n }\n\n return `getA(${sourceCoords[0]}),\n cEdge ? 0. : getA(${sourceCoords[1]}),\n rEdge ? 0. : getA(${sourceCoords[2]}),\n rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`;\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class PadProgram implements GPGPUProgram {\n variableNames = ['x'];\n outputShape: number[];\n userCode: string;\n\n constructor(\n xShape: number[], paddings: Array<[number, number]>,\n constantValue: number) {\n this.outputShape = paddings.map(\n (p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);\n const rank = xShape.length;\n const type = getCoordsDataType(rank);\n\n const start = paddings.map(p => p[0]).join(',');\n const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');\n const unpackedCoords =\n ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);\n\n if (rank === 1) {\n this.userCode = `\n int start = ${start};\n int end = ${end};\n\n void main() {\n int outC = getOutputCoords();\n if (outC < start || outC >= end) {\n setOutput(float(${constantValue}));\n } else {\n setOutput(getX(outC - start));\n }\n }\n `;\n return;\n }\n this.userCode = `\n ${type} start = ${type}(${start});\n ${type} end = ${type}(${end});\n\n void main() {\n ${type} outC = getOutputCoords();\n if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {\n setOutput(float(${constantValue}));\n } else {\n ${type} coords = outC - start;\n setOutput(getX(${unpackedCoords}));\n }\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getChannels} from '../packing_util';\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class PadPackedProgram implements GPGPUProgram {\n variableNames = ['x'];\n packedInputs = true;\n packedOutput = true;\n outputShape: number[];\n userCode: string;\n\n constructor(\n xShape: number[], paddings: Array<[number, number]>,\n constantValue: number) {\n this.outputShape = paddings.map(\n (p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);\n const rank = xShape.length;\n const dtype = getCoordsDataType(rank);\n\n const start = paddings.map(p => p[0]).join(',');\n const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');\n const coords = getChannels('rc', rank);\n const source = getChannels('source', rank);\n const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;\n const innerDims =\n rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;\n\n const componentSetup = [\n `${dtype} rc = outputLoc;`, `${coords[rank - 1]} += 1;\n if(${cLimit}) {\n `,\n rank === 1 ? '' : `}\n rc = outputLoc;\n ${coords[rank - 2]} += 1;\n if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {`,\n rank === 1 ? '' : ` ${coords[rank - 1]} += 1;\n if(${cLimit}) {`\n ];\n\n const paddingArea = rank === 1 ?\n 'rc < start || rc >= end' :\n 'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';\n let mainLoop = '';\n for (let i = 0, j = rank === 1 ? 2 : 4; i < j; i++) {\n mainLoop += `\n ${componentSetup[i]}\n if (${paddingArea}) {\n result[${i}] = float(${constantValue});\n } else {\n ${dtype} source = rc - start;\n result[${i}] = getChannel(getX(${source.join()}), ${innerDims});\n }\n `;\n }\n mainLoop += (rank === 1 ? `} ` : `}}`);\n\n this.userCode = `\n const ${dtype} start = ${dtype}(${start});\n const ${dtype} end = ${dtype}(${end});\n\n void main() {\n ${dtype} outputLoc = getOutputCoords();\n vec4 result = vec4(0.);\n ${mainLoop}\n setOutput(result);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Conv2DInfo} from '../../ops/conv_util';\nimport {Conv3DInfo} from '../../ops/conv_util';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class Pool2DProgram implements GPGPUProgram {\n variableNames = ['x'];\n outputShape: number[];\n userCode: string;\n\n constructor(\n convInfo: Conv2DInfo, poolType: 'max'|'avg', computePositions: boolean,\n flattenPositions = false, includeBatchInIndex = false) {\n if (poolType === 'avg' && computePositions) {\n throw new Error('Cannot compute positions for average pool.');\n }\n\n const filterWidth = convInfo.filterWidth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n this.outputShape = convInfo.outShape;\n\n const isAvgPool = poolType === 'avg';\n const batchFlattenPositionStr = `((batch * ${convInfo.inHeight} + xR) * ${\n convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;\n const flattenPositionStr =\n `(xR * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;\n\n let initializationValue = '0.0';\n if (!isAvgPool) {\n // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.\n initializationValue = '-1.0 / 1e-20';\n }\n\n if (computePositions) {\n const compareOp = '>=';\n\n this.userCode = `\n const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});\n const ivec2 pads = ivec2(${padTop}, ${padLeft});\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n float avgValue = 0.0;\n\n for (int wR = 0; wR < ${effectiveFilterHeight};\n wR += ${dilationHeight}) {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= ${convInfo.inHeight}) {\n continue;\n }\n\n for (int wC = 0; wC < ${effectiveFilterWidth};\n wC += ${dilationWidth}) {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= ${convInfo.inWidth}) {\n continue;\n }\n\n float value = getX(batch, xR, xC, d);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value ${compareOp} currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = ${\n flattenPositions ? (includeBatchInIndex ? batchFlattenPositionStr :\n flattenPositionStr) :\n `wR * ${effectiveFilterWidth} + wC`};\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n `;\n return;\n }\n\n const compareOp = 'max';\n\n let returnValue = `${poolType}(${poolType}(${poolType}(` +\n 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';\n if (poolType === 'avg') {\n returnValue = `avgValue / count`;\n }\n\n const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;\n const filterWidthVec4Remainder = filterWidth % 4;\n\n const updateSnippet = `\n if (${isAvgPool}) {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = ${compareOp}(values, minMaxValue);\n }\n `;\n\n this.userCode = `\n const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});\n const ivec2 pads = ivec2(${padTop}, ${padLeft});\n const float initializationValue = ${initializationValue};\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xR, int xC, int d) {\n if (xC < 0 || xC >= ${convInfo.inWidth}) {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xR, xC, d);\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n vec4 minMaxValue = vec4(${initializationValue});\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wR = 0; wR < ${effectiveFilterHeight};\n wR += ${dilationHeight}) {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= ${convInfo.inHeight}) {\n continue;\n }\n\n for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {\n int xC = xCCorner + wC * ${dilationWidth};\n\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + ${dilationWidth}, d),\n getValue(batch, xR, xC + 2 * ${dilationWidth}, d),\n getValue(batch, xR, xC + 3 * ${dilationWidth}, d)\n );\n\n ${updateSnippet}\n }\n\n int xC = xCCorner + ${filterWidthNearestVec4};\n if (${filterWidthVec4Remainder === 1}) {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n ${updateSnippet}\n } else if (${filterWidthVec4Remainder === 2}) {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + ${dilationWidth}, d),\n initializationValue,\n initializationValue\n );\n\n ${updateSnippet}\n } else if (${filterWidthVec4Remainder === 3}) {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + ${dilationWidth}, d),\n getValue(batch, xR, xC + 2 * ${dilationWidth}, d),\n initializationValue\n );\n\n ${updateSnippet}\n }\n }\n setOutput(${returnValue});\n }\n `;\n }\n}\n\nexport class Pool3DProgram implements GPGPUProgram {\n variableNames = ['x'];\n outputShape: number[];\n userCode: string;\n\n constructor(\n convInfo: Conv3DInfo, poolType: 'max'|'avg', computePositions: boolean,\n flattenPositions = false, includeBatchInIndex = false) {\n if (poolType === 'avg' && computePositions) {\n throw new Error('Cannot compute positions for average pool.');\n }\n\n const filterWidth = convInfo.filterWidth;\n const strideDepth = convInfo.strideDepth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationDepth = convInfo.dilationDepth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterDepth = convInfo.effectiveFilterDepth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n\n const padFront = convInfo.padInfo.front;\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n this.outputShape = convInfo.outShape;\n\n const isAvgPool = poolType === 'avg';\n\n let initializationValue = '0.0';\n if (!isAvgPool) {\n // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.\n initializationValue = '-1.0 / 1e-20';\n }\n\n if (computePositions) {\n const compareOp = '>=';\n\n this.userCode = `\n const ivec3 strides =\n ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});\n const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n\n for (int wD = 0; wD < ${effectiveFilterDepth};\n wD += ${dilationDepth}) {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= ${convInfo.inDepth}) {\n continue;\n }\n\n for (int wR = 0; wR < ${effectiveFilterHeight};\n wR += ${dilationHeight}) {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= ${convInfo.inHeight}) {\n continue;\n }\n\n for (int wC = 0; wC < ${effectiveFilterWidth};\n wC += ${dilationWidth}) {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= ${convInfo.inWidth}) {\n continue;\n }\n\n float value = getX(batch, xD, xR, xC, ch);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value ${compareOp} currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = ${\n flattenPositions ?\n (includeBatchInIndex ?\n `(((batch * ${convInfo.inDepth} + xD) * ${\n convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${\n convInfo.inChannels} + ch` :\n `((xD * ${convInfo.inHeight} + xR) * ${\n convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch`) :\n `wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +\n wR * ${effectiveFilterWidth} + wC`};\n }\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n `;\n return;\n }\n\n const compareOp = 'max';\n\n let returnValue = `${poolType}(${poolType}(${poolType}(` +\n 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';\n if (poolType === 'avg') {\n returnValue = `avgValue / count`;\n }\n\n const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;\n const filterWidthVec4Remainder = filterWidth % 4;\n\n const updateSnippet = `\n if (${isAvgPool}) {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = ${compareOp}(values, minMaxValue);\n }\n `;\n\n this.userCode = `\n const ivec3 strides =\n ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});\n const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});\n const float initializationValue = ${initializationValue};\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xD, int xR, int xC, int ch) {\n if (xC < 0 || xC >= ${convInfo.inWidth}) {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xD, xR, xC, ch);\n }\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).\n // ? = to be determined\n vec4 minMaxValue = vec4(${initializationValue});\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wD = 0; wD < ${effectiveFilterDepth};\n wD += ${dilationDepth}) {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= ${convInfo.inDepth}) {\n continue;\n }\n\n for (int wR = 0; wR < ${effectiveFilterHeight};\n wR += ${dilationHeight}) {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= ${convInfo.inHeight}) {\n continue;\n }\n\n for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {\n int xC = xCCorner + wC * ${dilationWidth};\n\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + ${dilationWidth}, ch),\n getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),\n getValue(batch, xD, xR, xC + 3 * ${dilationWidth}, ch)\n );\n\n ${updateSnippet}\n }\n\n int xC = xCCorner + ${filterWidthNearestVec4};\n if (${filterWidthVec4Remainder === 1}) {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n ${updateSnippet}\n } else if (${filterWidthVec4Remainder === 2}) {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + ${dilationWidth}, ch),\n initializationValue,\n initializationValue\n );\n\n ${updateSnippet}\n } else if (${filterWidthVec4Remainder === 3}) {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + ${dilationWidth}, ch),\n getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),\n initializationValue\n );\n\n ${updateSnippet}\n }\n }\n setOutput(${returnValue});\n }\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ReduceInfo} from '../../ops/reduce_util';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class ReduceProgram implements GPGPUProgram {\n variableNames = ['x'];\n outputShape: number[];\n userCode: string;\n\n constructor(\n reduceInfo: ReduceInfo,\n reduceType: 'all'|'any'|'max'|'min'|'sum'|'prod') {\n const windowSize = reduceInfo.windowSize;\n const batchSize = reduceInfo.batchSize;\n const inSize = reduceInfo.inSize;\n const outSize = Math.ceil(inSize / windowSize);\n this.outputShape = [batchSize, outSize];\n\n let initializationValue = '0.0';\n let compareOp = ``;\n\n if (reduceType === 'prod') {\n initializationValue = '1.0';\n } else if (reduceType === 'min') {\n // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.\n initializationValue = '1.0 / 1e-20';\n compareOp = `min`;\n } else if (reduceType === 'max') {\n // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.\n initializationValue = '-1.0 / 1e-20';\n compareOp = `max`;\n }\n\n let returnValue = `${reduceType}(${reduceType}(${reduceType}(` +\n 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';\n\n if (reduceType === 'sum') {\n returnValue = `sumValue`;\n } else if (reduceType === 'prod') {\n returnValue = `prodValue`;\n } else if (reduceType === 'all') {\n returnValue = `allValue`;\n } else if (reduceType === 'any') {\n returnValue = `anyValue`;\n }\n\n const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;\n const windowSizeVec4Remainder = windowSize % 4;\n\n let updateSnippet = `\n if (${reduceType === 'sum'}) {\n sumValue += dot(values, ones);\n } else if (${reduceType === 'prod'}) {\n vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);\n prodValue *= tmp[0] * tmp[1];\n } else {\n minMaxValue = ${compareOp}(values, minMaxValue);\n }\n `;\n\n let vecType = `vec4`;\n\n if (reduceType === 'all') {\n initializationValue = '1.0';\n updateSnippet = `\n bool reducedAllValue = all(values);\n float floatedReducedAllValue = float(reducedAllValue);\n allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n `;\n vecType = `bvec4`;\n } else if (reduceType === 'any') {\n initializationValue = '0.0';\n updateSnippet = `\n bool reducedAnyValue = any(values);\n float floatedReducedAnyValue = float(reducedAnyValue);\n anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n `;\n vecType = `bvec4`;\n }\n\n let checkOutOfBounds = '';\n if (inSize % windowSize > 0) {\n checkOutOfBounds = `\n if (inIdx < 0 || inIdx >= ${inSize}) {\n return initializationValue;\n }\n `;\n }\n this.userCode = `\n const float initializationValue = ${initializationValue};\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n ${checkOutOfBounds}\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * ${windowSize};\n\n vec4 minMaxValue = vec4(${initializationValue});\n float prodValue = 1.0;\n float sumValue = 0.0;\n float allValue = 1.0;\n float anyValue = 0.0;\n\n for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {\n int inIdx = inOffset + i;\n ${vecType} values = ${vecType}(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n ${updateSnippet}\n }\n\n int inIdx = inOffset + ${windowSizeNearestVec4};\n if (${windowSizeVec4Remainder === 1}) {\n ${vecType} values = ${vecType}(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n ${updateSnippet}\n } else if (${windowSizeVec4Remainder === 2}) {\n ${vecType} values = ${vecType}(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n ${updateSnippet}\n } else if (${windowSizeVec4Remainder === 3}) {\n ${vecType} values = ${vecType}(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n ${updateSnippet}\n }\n setOutput(${returnValue});\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport * as shader_util from './shader_compiler_util';\n\nexport class ReshapePackedProgram implements GPGPUProgram {\n variableNames = ['A'];\n packedInputs = true;\n packedOutput = true;\n outputShape: number[];\n userCode: string;\n\n constructor(outputShape: [number, number, number], inputShape: [\n number, number, number\n ]) {\n this.outputShape = outputShape;\n\n let mainLoop = ``;\n for (let i = 0; i < 4; i++) {\n let thisRC = `thisRC = rc;`;\n if (i % 2 === 1) {\n thisRC += `thisRC.z += 1;`;\n }\n if (i > 1) {\n thisRC += `thisRC.y += 1;`;\n }\n\n mainLoop += `\n ${thisRC}\n ${i > 0 ? `if(thisRC.y < rows && thisRC.z < cols){` : ''}\n int flatIndex = getFlatIndex(thisRC);\n\n ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);\n vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));\n\n result[${i}] =\n getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);\n ${i > 0 ? '}' : ''}\n `;\n }\n\n this.userCode = `\n ${getReshapedInputCoords(inputShape)}\n ${shader_util.getFlatIndexFrom3D(outputShape)}\n\n void main() {\n ivec3 rc = getOutputCoords();\n\n vec4 result = vec4(0.);\n\n ivec3 thisRC;\n int rows = ${outputShape[1]};\n int cols = ${outputShape[2]};\n\n ${mainLoop}\n\n setOutput(result);\n }\n `;\n }\n}\n\nfunction getReshapedInputCoords(shape: [number, number, number]): string {\n const coordsFromIndexSnippet =\n shader_util.getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);\n\n return `\n ivec3 inputCoordsFromReshapedOutCoords(int index) {\n ${coordsFromIndexSnippet}\n return ivec3(r, c, d);\n }\n `;\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor4D} from '../../tensor';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class ResizeBilinearBackpropProgram implements GPGPUProgram {\n variableNames = ['dy'];\n outputShape: number[] = [];\n userCode: string;\n\n constructor(dy: Tensor4D, x: Tensor4D, alignCorners: boolean) {\n this.outputShape = x.shape;\n const [, xHeight, xWidth, ] = x.shape;\n const [, yHeight, yWidth] = dy.shape;\n\n // In the backwards pass, we want to find the pixels that were generated for\n // each pixel in the input image the forward pass and add the corresponding\n // coefficient from dy to the gradient (with some interpolation).\n\n const effectiveXSize: [number, number] = [\n (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,\n (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth\n ];\n\n const effectiveYSize: [number, number] = [\n (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,\n (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth\n ];\n\n const heightScale = effectiveXSize[0] / effectiveYSize[0];\n const widthScale = effectiveXSize[1] / effectiveYSize[1];\n\n const invHeightScale = 1 / heightScale;\n const invWidthScale = 1 / widthScale;\n\n // This defines the size of the window of values around a particular\n // index in dy that we want to search for contributions to dx.\n const winHeight = (Math.ceil(invHeightScale) * 2) + 2;\n const winWidth = (Math.ceil(invWidthScale) * 2) + 2;\n\n this.userCode = `\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(${heightScale});\n const float widthScale = float(${widthScale});\n\n const float invHeightScale = float(${invHeightScale});\n const float invWidthScale = float(${invWidthScale});\n\n const int winHeight = int(${winHeight});\n const int winWidth = int(${winWidth});\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(startRLerp - float(winHeight / 2));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(startCLerp - float(winWidth / 2));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= ${yHeight}) {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= ${yWidth}) {\n continue;\n }\n\n float dxR = float(dyR) * heightScale;\n int topDxRIndex = int(floor(dxR));\n int bottomDxRIndex = int(min(ceil(dxR), ${xHeight - 1}.0));\n float dxRLerp = dxR - float(topDxRIndex);\n float inverseDxRLerp = 1.0 - dxRLerp;\n\n float dxC = float(dyC) * widthScale;\n int leftDxCIndex = int(floor(dxC));\n int rightDxCIndex = int(min(ceil(dxC), ${xWidth - 1}.0));\n float dxCLerp = dxC - float(leftDxCIndex);\n float inverseDxCLerp = 1.0 - dxCLerp;\n\n if (r == topDxRIndex && c == leftDxCIndex) {\n // topLeft\n accumulator +=\n getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;\n }\n\n if (r == topDxRIndex && c == rightDxCIndex) {\n // topRight\n accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;\n }\n\n if (r == bottomDxRIndex && c == leftDxCIndex) {\n // bottomLeft\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;\n }\n\n if (r == bottomDxRIndex && c == rightDxCIndex) {\n // bottomRight\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class ResizeBilinearProgram implements GPGPUProgram {\n variableNames = ['A'];\n outputShape: number[] = [];\n userCode: string;\n\n constructor(\n inputShape: [number, number, number, number], newHeight: number,\n newWidth: number, alignCorners: boolean) {\n const [batch, oldHeight, oldWidth, depth] = inputShape;\n this.outputShape = [batch, newHeight, newWidth, depth];\n\n const effectiveInSize: [number, number] = [\n (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,\n (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth\n ];\n\n const effectiveOutSize: [number, number] = [\n (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,\n (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth\n ];\n\n this.userCode = `\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n ${effectiveInSize[0] / effectiveOutSize[0]},\n ${effectiveInSize[1] / effectiveOutSize[1]});\n const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the four integer indices.\n ivec2 sourceFloorRC = ivec2(sourceFracIndexRC);\n ivec2 sourceCeilRC = ivec2(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);\n float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);\n float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);\n float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);\n\n vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);\n\n float top = topLeft + (topRight - topLeft) * fracRC.y;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;\n float newValue = top + (bottom - top) * fracRC.x;\n\n setOutput(newValue);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class ResizeBilinearPackedProgram implements GPGPUProgram {\n variableNames = ['A'];\n packedInputs = true;\n packedOutput = true;\n outputShape: number[] = [];\n userCode: string;\n\n constructor(\n inputShape: [number, number, number, number], newHeight: number,\n newWidth: number, alignCorners: boolean) {\n const [batch, oldHeight, oldWidth, depth] = inputShape;\n this.outputShape = [batch, newHeight, newWidth, depth];\n\n const effectiveInSize: [number, number] = [\n (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,\n (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth\n ];\n\n const effectiveOutSize: [number, number] = [\n (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,\n (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth\n ];\n\n this.userCode = `\n const vec3 effectiveInputOverOutputRatioRC = vec3(\n ${effectiveInSize[0] / effectiveOutSize[0]},\n ${effectiveInSize[1] / effectiveOutSize[1]},\n ${effectiveInSize[1] / effectiveOutSize[1]});\n const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0,\n ${oldWidth}.0);\n\n float getAValue(int b, int r, int c, int d) {\n return getChannel(getA(b, r, c, d), vec2(c, d));\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n // Calculate values for next column in yRC.z.\n ivec3 yRC = coords.yzz + ivec3(0, 0, 1);\n\n // Fractional source index.\n vec3 sourceFracIndexRC = vec3(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the four integer indices.\n ivec3 sourceFloorRC = ivec3(sourceFracIndexRC);\n ivec3 sourceCeilRC = ivec3(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n // Should we calculate next column and row elements in 2x2 packed cell.\n bool hasNextCol = d < ${depth - 1};\n bool hasNextRow = coords.z < ${newWidth - 1};\n\n // In parallel, construct four corners for all four components in\n // packed 2x2 cell.\n vec4 topLeft = vec4(\n getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 bottomLeft = vec4(\n getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 topRight = vec4(\n getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec4 bottomRight = vec4(\n getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);\n\n vec4 top = mix(topLeft, topRight, fracRC.yyzz);\n vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);\n vec4 newValue = mix(top, bottom, fracRC.x);\n\n setOutput(newValue);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor4D} from '../../tensor';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class ResizeNearestNeigborBackpropProgram implements GPGPUProgram {\n variableNames = ['dy'];\n outputShape: number[] = [];\n userCode: string;\n\n constructor(dy: Tensor4D, x: Tensor4D, alignCorners: boolean) {\n this.outputShape = x.shape;\n const [, xHeight, xWidth, ] = x.shape;\n const [, yHeight, yWidth] = dy.shape;\n\n // In the backwards pass, we want to find the pixels that were generated for\n // each pixel in the input image the forward pass and add the corresponding\n // coefficient from dy to the gradient (with some interpolation).\n\n const effectiveXSize: [number, number] = [\n (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,\n (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth\n ];\n\n const effectiveYSize: [number, number] = [\n (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,\n (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth\n ];\n\n const heightScale = effectiveXSize[0] / effectiveYSize[0];\n const widthScale = effectiveXSize[1] / effectiveYSize[1];\n\n const invHeightScale = 1 / heightScale;\n const invWidthScale = 1 / widthScale;\n\n // This defines the size of the window of values around a particular\n // index in dy that we want to search for contributions to dx.\n const winHeight = (Math.ceil(invHeightScale) * 2) + 2;\n const winWidth = (Math.ceil(invWidthScale) * 2) + 2;\n\n this.userCode = `\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(${heightScale});\n const float widthScale = float(${widthScale});\n\n const float invHeightScale = float(${invHeightScale});\n const float invWidthScale = float(${invWidthScale});\n\n const int winHeight = int(${winHeight});\n const int winWidth = int(${winWidth});\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(floor(startRLerp - float(winHeight / 2)));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(floor(startCLerp - float(winWidth / 2)));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= ${yHeight}) {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= ${yWidth}) {\n continue;\n }\n\n float sourceFracRow =\n float(${effectiveXSize[0]}) *\n (float(dyR) / float(${effectiveYSize[0]}));\n\n float sourceFracCol =\n float(${effectiveXSize[1]}) *\n (float(dyC) / float(${effectiveYSize[1]}));\n\n int sourceNearestRow = int(min(\n float(int(${xHeight}) - 1),\n ${alignCorners} ? float(round(sourceFracRow)) :\n float(floor(sourceFracRow))));\n\n int sourceNearestCol = int(min(\n float(int(${xWidth}) - 1),\n ${alignCorners} ? float(round(sourceFracCol)) :\n float(floor(sourceFracCol))));\n\n if (r == sourceNearestRow && c == sourceNearestCol) {\n accumulator += getDy(b, dyR, dyC, d);\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class ResizeNearestNeighborProgram implements GPGPUProgram {\n variableNames = ['A'];\n outputShape: number[] = [];\n userCode: string;\n\n constructor(\n inputShape: [number, number, number, number], newHeight: number,\n newWidth: number, alignCorners: boolean) {\n const [batch, oldHeight, oldWidth, depth] = inputShape;\n this.outputShape = [batch, newHeight, newWidth, depth];\n\n const effectiveInSize: [number, number] = [\n (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,\n (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth\n ];\n\n const effectiveOutSize: [number, number] = [\n (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,\n (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth\n ];\n\n // When align corners is false, we rounds the value with floor.\n const roundBase = alignCorners ? '0.5' : '0.0';\n\n this.userCode = `\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n ${effectiveInSize[0] / effectiveOutSize[0]},\n ${effectiveInSize[1] / effectiveOutSize[1]});\n const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestRC = ivec2(\n min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase})));\n\n float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);\n\n setOutput(newValue);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class ReverseProgram implements GPGPUProgram {\n variableNames = ['x'];\n outputShape: number[];\n userCode: string;\n\n constructor(xShape: number[], axis: number[]) {\n const rank = xShape.length;\n if (rank > 4) {\n throw new Error(\n `WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);\n }\n this.outputShape = xShape;\n\n if (rank === 1) {\n this.userCode = `\n void main() {\n int coord = getOutputCoords();\n setOutput(getX(${xShape[0]} - coord - 1));\n }\n `;\n return;\n }\n const getInCoord = (i: number) => {\n if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {\n return `${xShape[i]} - coords[${i}] - 1`;\n }\n return `coords[${i}]`;\n };\n const inCoords = xShape.map((_, i) => getInCoord(i)).join(',');\n const type = getCoordsDataType(rank);\n\n this.userCode = `\n void main() {\n ${type} coords = getOutputCoords();\n setOutput(getX(${inCoords}));\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2019 Google LLC All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getChannels} from '../packing_util';\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class ReversePackedProgram implements GPGPUProgram {\n variableNames = ['x'];\n outputShape: number[];\n userCode: string;\n packedInputs = true;\n packedOutput = true;\n\n constructor(xShape: number[], axis: number[]) {\n const rank = xShape.length;\n if (rank > 4) {\n throw new Error(\n `WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);\n }\n this.outputShape = xShape;\n const channels = getChannels('rc', rank);\n const nextColumn =\n `${channels[rank - 1]} + 1 < ${this.outputShape[rank - 1]}`;\n const nextRow = `${channels[rank - 2]} + 1 < ${this.outputShape[rank - 2]}`;\n const type = getCoordsDataType(rank);\n if (rank === 1) {\n this.userCode = `\n void main(){\n int rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = getChannel(getX(${xShape[0]} - rc - 1),\n ${xShape[0]} - rc - 1);\n if(${nextColumn}){\n result.g = getChannel(getX(${xShape[0]} - (rc + 1) - 1),\n ${xShape[0]} - (rc + 1) - 1);\n }\n setOutput(result);\n }\n `;\n } else {\n this.userCode = `\n void main() {\n ${type} rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = ${getR(channels.slice())};\n if(${nextColumn}){\n result.g = ${getG(channels.slice())};\n }\n if(${nextRow}) {\n result.b = ${getB(channels.slice())};\n if(${nextColumn}) {\n result.a = ${getA(channels.slice())};\n }\n }\n setOutput(result);\n }\n `;\n }\n\n function getR(channels: string[]): string {\n return getChannel(channels);\n }\n\n function getG(channels: string[]): string {\n channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`;\n return getChannel(channels);\n }\n\n function getB(channels: string[]): string {\n channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`;\n return getChannel(channels);\n }\n\n function getA(channels: string[]): string {\n channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`;\n channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`;\n return getChannel(channels);\n }\n\n function getChannel(channels: string[]): string {\n const inCoordsArray = xShape.map((_, i) => getInCoord(i, channels));\n const inCoords = inCoordsArray.join(',');\n const innerDims = inCoordsArray.slice(-2).join(',');\n return `getChannel(getX(${inCoords}), vec2(${innerDims}))`;\n }\n\n function getInCoord(i: number, channels1: string[]): string {\n if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {\n return `${xShape[i]} - ${channels1[i]} - 1`;\n } else {\n return `${channels1[i]}`;\n }\n }\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class ScatterProgram implements GPGPUProgram {\n variableNames = ['updates', 'indices', 'defaultValue'];\n outputShape: number[];\n userCode: string;\n\n constructor(\n updateSize: number, sliceDim: number, indicesRank: number,\n updatesRank: number, strides: number[], shape: number[],\n summingDupeIndex = true) {\n this.outputShape = shape;\n const stridesType = getCoordsDataType(strides.length);\n const dtype = getCoordsDataType(shape.length);\n let indicesString = '';\n if (indicesRank === 1) {\n indicesString = 'i';\n } else if (indicesRank === 2) {\n indicesString = 'i, j';\n }\n const indicesSnippet = `getIndices(${indicesString})`;\n\n let updatesString = '';\n if (updatesRank === 1) {\n updatesString = 'i';\n } else if (updatesRank === 2) {\n updatesString = 'i, coords[1]';\n }\n const updatesSnippet = `getUpdates(${updatesString})`;\n\n const strideString = sliceDim > 1 ? 'strides[j]' : 'strides';\n this.userCode = `\n ${stridesType} strides = ${stridesType}(${strides});\n\n void main() {\n ${dtype} coords = getOutputCoords();\n float sum = 0.0;\n bool found = false;\n for (int i = 0; i < ${updateSize}; i++) {\n int flattenedIndex = 0;\n for (int j = 0; j < ${sliceDim}; j++) {\n int index = round(${indicesSnippet});\n flattenedIndex += index * ${strideString};\n }\n if (flattenedIndex == coords[0]) {\n sum += ${updatesSnippet};\n found = true;\n }\n }\n setOutput(mix(getDefaultValue(), sum, float(found)));\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {SegOpInfo} from '../../ops/segment_util';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class SegmentOpProgram implements GPGPUProgram {\n variableNames = ['x', 'segmentIds'];\n outputShape: number[];\n userCode: string;\n\n constructor(segOpInfo: SegOpInfo, segOpType: 'unsortedSegmentSum') {\n const windowSize = segOpInfo.windowSize;\n const batchSize = segOpInfo.batchSize;\n const inSize = segOpInfo.inSize;\n const numSegments = segOpInfo.numSegments;\n const outSize = numSegments * Math.ceil(inSize / windowSize);\n this.outputShape = [batchSize, outSize];\n\n const initializationValue = '0.0';\n const returnValue = `sumValue`;\n\n const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;\n const windowSizeVec4Remainder = windowSize % 4;\n\n const updateSnippet = `\n sumValue += dot(values, segFilter);\n `;\n\n let checkValueOutOfBounds = '';\n if (inSize % windowSize > 0) {\n checkValueOutOfBounds = `\n if (inIdx < 0 || inIdx >= ${inSize}) {\n return initializationValue;\n }\n `;\n }\n\n let checkSegmentIdOutOfBounds = '';\n if (inSize % windowSize > 0) {\n checkSegmentIdOutOfBounds = `\n if (inIdx < 0 || inIdx >= ${inSize}) {\n return -1.0;\n }\n `;\n }\n\n this.userCode = `\n const float initializationValue = ${initializationValue};\n\n float getValue(int batch, int inIdx) {\n ${checkValueOutOfBounds}\n return getX(batch, inIdx);\n }\n\n float getSegmentIdAtIndex(int inIdx) {\n ${checkSegmentIdOutOfBounds}\n return getSegmentIds(inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = int(floor(float(outIdx) / float(\n ${numSegments})) * float(${windowSize}));\n int currentSeg = int(mod(float(outIdx), float(${numSegments})));\n\n float sumValue = 0.0;\n\n for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0\n );\n\n ${updateSnippet}\n }\n\n int inIdx = inOffset + ${windowSizeNearestVec4};\n if (${windowSizeVec4Remainder === 1}) {\n vec4 values = vec4(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n int inIdxSeg = int(getSegmentIdAtIndex(inIdx));\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n 0,\n 0,\n 0\n );\n\n ${updateSnippet}\n } else if (${windowSizeVec4Remainder === 2}) {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n 0,\n 0\n );\n\n ${updateSnippet}\n } else if (${windowSizeVec4Remainder === 3}) {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n 0\n );\n\n ${updateSnippet}\n }\n setOutput(${returnValue});\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class SelectProgram implements GPGPUProgram {\n variableNames = ['c', 'a', 'b'];\n outputShape: number[];\n userCode: string;\n\n constructor(cRank: number, shape: number[], rank: number) {\n this.outputShape = shape;\n\n let cCoords;\n let abCoords;\n if (rank > 4) {\n throw Error(`Where for rank ${rank} is not yet supported`);\n }\n\n if (rank === 1) {\n abCoords = `resRC`;\n cCoords = `resRC`;\n } else {\n const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];\n const cCoordVars = [];\n const abCoordVars = [];\n for (let i = 0; i < shape.length; i++) {\n abCoordVars.push(`${currentCoords[i]}`);\n if (i < cRank) {\n cCoordVars.push(`${currentCoords[i]}`);\n }\n }\n cCoords = cCoordVars.join();\n abCoords = abCoordVars.join();\n }\n\n const dtype = getCoordsDataType(rank);\n\n this.userCode = `\n void main() {\n ${dtype} resRC = getOutputCoords();\n float cVal = getC(${cCoords});\n if (cVal >= 1.0) {\n setOutput(getA(${abCoords}));\n } else {\n setOutput(getB(${abCoords}));\n }\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUContext} from './gpgpu_context';\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class SliceProgram implements GPGPUProgram {\n variableNames = ['source'];\n outputShape: number[];\n userCode: string;\n rank: number;\n\n // Caching uniform location for speed.\n startLoc: WebGLUniformLocation;\n\n constructor(destSize: number[]) {\n this.outputShape = destSize;\n this.rank = destSize.length;\n\n const dtype = getCoordsDataType(this.rank);\n const uniformPart = `uniform int start[${this.rank}];`;\n const sourceCoords = getCoords(this.rank);\n\n let body: string;\n const coordSum = destSize.map((_, i) => {\n return `sourceLoc.${coords[i]} = start[${i}] + coords.${coords[i]};`;\n });\n body = `\n ${dtype} sourceLoc;\n ${dtype} coords = getOutputCoords();\n ${coordSum.join('\\n')}\n `;\n this.userCode = `\n ${uniformPart}\n void main() {\n ${body}\n setOutput(getSource(${sourceCoords}));\n }\n `;\n }\n\n getCustomSetupFunc(start: number[]) {\n if (start.length !== this.rank) {\n throw Error(\n `The rank (${this.rank}) of the program must match the ` +\n `length of start (${start.length})`);\n }\n return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {\n if (this.startLoc == null) {\n this.startLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'start');\n if (this.startLoc == null) {\n // This means the compiler has optimized and realized it doesn't need\n // the uniform.\n return;\n }\n }\n gpgpu.gl.uniform1iv(this.startLoc, start);\n };\n }\n}\n\nconst coords = ['x', 'y', 'z', 'w', 'u', 'v'];\n\nfunction getCoords(rank: number): string {\n if (rank === 1) {\n return 'sourceLoc';\n } else if (rank <= 6) {\n return coords.slice(0, rank).map(x => 'sourceLoc.' + x).join(',');\n } else {\n throw Error(`Slicing for rank ${rank} is not yet supported`);\n }\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getChannels} from '../packing_util';\n\nimport {GPGPUContext} from './gpgpu_context';\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class SlicePackedProgram implements GPGPUProgram {\n variableNames = ['source'];\n packedInputs = true;\n packedOutput = true;\n outputShape: number[];\n userCode: string;\n rank: number;\n\n // Caching uniform location for speed.\n startLoc: WebGLUniformLocation;\n\n constructor(destSize: number[]) {\n this.outputShape = destSize;\n this.rank = destSize.length;\n\n const dtype = getCoordsDataType(this.rank);\n const coords = getChannels('coords', this.rank);\n const sourceLoc = getChannels('sourceLoc', this.rank);\n\n const innerDims =\n this.rank === 1 ? 'sourceLoc' : `vec2(${sourceLoc.slice(-2).join()})`;\n const getChannel =\n `getChannel(getSource(${sourceLoc.join()}), ${innerDims})`;\n const upperRow = `\n result.x = ${getChannel};\n if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) {\n ++${sourceLoc[this.rank - 1]};\n result.y = ${getChannel};\n --${sourceLoc[this.rank - 1]};\n }\n `;\n const lowerRow = this.rank === 1 ? '' : `\n --${coords[this.rank - 1]};\n if (++${coords[this.rank - 2]} < ${destSize[this.rank - 2]}) {\n ++${sourceLoc[this.rank - 2]};\n result.z = ${getChannel};\n if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) {\n ++${sourceLoc[this.rank - 1]};\n result.w = ${getChannel};\n }\n }\n `;\n\n const sourceLocSetup = this.rank <= 4 ?\n `sourceLoc = coords +\n ${dtype}(${destSize.map((_, i) => `start[${i}]`).join()});` :\n destSize.map((_, i) => `${sourceLoc[i]} = ${coords[i]} + start[${i}];`)\n .join('\\n');\n this.userCode = `\n uniform int start[${this.rank}];\n void main() {\n ${dtype} coords = getOutputCoords();\n ${dtype} sourceLoc;\n ${sourceLocSetup}\n vec4 result = vec4(0.);\n ${upperRow}\n ${lowerRow}\n setOutput(result);\n }\n `;\n }\n\n getCustomSetupFunc(start: number[]) {\n if (start.length !== this.rank) {\n throw Error(\n `The rank (${this.rank}) of the program must match the ` +\n `length of start (${start.length})`);\n }\n return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {\n if (this.startLoc == null) {\n this.startLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'start');\n if (this.startLoc == null) {\n // This means the compiler has optimized and realized it doesn't need\n // the uniform.\n return;\n }\n }\n gpgpu.gl.uniform1iv(this.startLoc, start);\n };\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class StridedSliceProgram implements GPGPUProgram {\n variableNames = ['x'];\n outputShape: number[];\n userCode: string;\n\n constructor(\n begin: number[], strides: number[], size: number[]) {\n this.outputShape = size;\n const rank = size.length;\n const inputDtype = getCoordsDataType(size.length);\n const dtype = getCoordsDataType(size.length);\n\n let newCoords = '';\n if (rank === 1) {\n newCoords = 'coords * strides + begin';\n } else {\n let outputAxis = 0;\n newCoords =\n size.map((_, i) => {\n outputAxis++;\n return size.length === 1 ?\n `coords * strides[${i}] + begin[${i}]` :\n `coords[${outputAxis - 1}] * strides[${i}] + begin[${i}]`;\n })\n .join(',');\n }\n\n this.userCode = `\n ${inputDtype} begin = ${inputDtype}(${begin});\n ${inputDtype} strides = ${inputDtype}(${strides});\n\n void main() {\n ${dtype} coords = getOutputCoords();\n setOutput(getX(${newCoords}));\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../../environment';\n\nimport {GPGPUContext} from './gpgpu_context';\nimport {PhysicalTextureType, TextureUsage} from './tex_util';\n\nexport class TextureManager {\n private numUsedTextures = 0;\n private numFreeTextures = 0;\n private freeTextures: {[shape: string]: WebGLTexture[]} = {};\n private logEnabled = false;\n private usedTextures: {[shape: string]: WebGLTexture[]} = {};\n\n constructor(private gpgpu: GPGPUContext) {}\n\n acquireTexture(\n shapeRC: [number, number], usage: TextureUsage,\n isPacked: boolean): WebGLTexture {\n const physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked);\n\n const shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked);\n if (!(shapeKey in this.freeTextures)) {\n this.freeTextures[shapeKey] = [];\n }\n if (!(shapeKey in this.usedTextures)) {\n this.usedTextures[shapeKey] = [];\n }\n\n if (this.freeTextures[shapeKey].length > 0) {\n this.numFreeTextures--;\n this.numUsedTextures++;\n this.log();\n const newTexture = this.freeTextures[shapeKey].shift();\n this.usedTextures[shapeKey].push(newTexture);\n return newTexture;\n }\n this.numUsedTextures++;\n this.log();\n\n let newTexture: WebGLTexture;\n if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {\n newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]);\n } else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {\n newTexture =\n this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]);\n } else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {\n newTexture =\n this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]);\n } else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {\n newTexture =\n this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]);\n\n } else if (\n physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {\n newTexture =\n this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]);\n }\n this.usedTextures[shapeKey].push(newTexture);\n\n return newTexture;\n }\n\n releaseTexture(\n texture: WebGLTexture, shape: [number, number],\n logicalTexType: TextureUsage, isPacked: boolean): void {\n if (this.freeTextures == null) {\n // Already disposed.\n return;\n }\n const physicalTexType =\n getPhysicalFromLogicalTextureType(logicalTexType, isPacked);\n const shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked);\n if (!(shapeKey in this.freeTextures)) {\n this.freeTextures[shapeKey] = [];\n }\n this.freeTextures[shapeKey].push(texture);\n this.numFreeTextures++;\n this.numUsedTextures--;\n const texList = this.usedTextures[shapeKey];\n const texIndex = texList.indexOf(texture);\n if (texIndex < 0) {\n throw new Error(\n 'Cannot release a texture that was never provided by this ' +\n 'texture manager');\n }\n texList.splice(texIndex, 1);\n this.log();\n }\n\n private log() {\n if (!this.logEnabled) {\n return;\n }\n const total = this.numFreeTextures + this.numUsedTextures;\n console.log(\n 'Free/Used', `${this.numFreeTextures} / ${this.numUsedTextures}`,\n `(${total})`);\n }\n\n getNumUsedTextures(): number {\n return this.numUsedTextures;\n }\n\n getNumFreeTextures(): number {\n return this.numFreeTextures;\n }\n\n dispose() {\n if (this.freeTextures == null) {\n // Already disposed.\n return;\n }\n for (const texShape in this.freeTextures) {\n this.freeTextures[texShape].forEach(tex => {\n this.gpgpu.deleteMatrixTexture(tex);\n });\n }\n for (const texShape in this.usedTextures) {\n this.usedTextures[texShape].forEach(tex => {\n this.gpgpu.deleteMatrixTexture(tex);\n });\n }\n this.freeTextures = null;\n this.usedTextures = null;\n this.numUsedTextures = 0;\n this.numFreeTextures = 0;\n }\n}\n\nfunction getPhysicalTextureForRendering(isPacked: boolean):\n PhysicalTextureType {\n if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) {\n if (isPacked) {\n return PhysicalTextureType.PACKED_2X2_FLOAT32;\n }\n return PhysicalTextureType.UNPACKED_FLOAT32;\n }\n\n if (isPacked) {\n return PhysicalTextureType.PACKED_2X2_FLOAT16;\n }\n return PhysicalTextureType.UNPACKED_FLOAT16;\n}\n\nfunction getPhysicalFromLogicalTextureType(\n logicalTexType: TextureUsage, isPacked: boolean): PhysicalTextureType {\n if (logicalTexType === TextureUsage.UPLOAD) {\n return PhysicalTextureType.PACKED_2X2_FLOAT32;\n } else if (logicalTexType === TextureUsage.RENDER || logicalTexType == null) {\n return getPhysicalTextureForRendering(isPacked);\n } else if (\n logicalTexType === TextureUsage.DOWNLOAD ||\n logicalTexType === TextureUsage.PIXELS) {\n return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE;\n }\n throw new Error(`Unknown logical texture type ${logicalTexType}`);\n}\n\nfunction getKeyFromTextureShape(\n shapeRowsCol: [number, number], physicalTexType: PhysicalTextureType,\n isPacked: boolean): string {\n return `${shapeRowsCol[0]}_${shapeRowsCol[1]}_${physicalTexType}_${isPacked}`;\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class TileProgram implements GPGPUProgram {\n variableNames = ['A'];\n outputShape: number[];\n userCode: string;\n rank: number;\n\n constructor(aShape: number[], reps: number[]) {\n const outputShape: number[] = new Array(aShape.length);\n for (let i = 0; i < outputShape.length; i++) {\n outputShape[i] = aShape[i] * reps[i];\n }\n this.outputShape = outputShape;\n this.rank = outputShape.length;\n const dtype = getCoordsDataType(this.rank);\n const sourceCoords = getSourceCoords(aShape);\n\n this.userCode = `\n void main() {\n ${dtype} resRC = getOutputCoords();\n setOutput(getA(${sourceCoords}));\n }\n `;\n }\n}\n\nfunction getSourceCoords(aShape: number[]): string {\n const rank = aShape.length;\n if (rank > 5) {\n throw Error(`Tile for rank ${rank} is not yet supported`);\n }\n if (rank === 1) {\n return `imod(resRC, ${aShape[0]})`;\n }\n\n const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u'];\n\n const sourceCoords = [];\n for (let i = 0; i < aShape.length; i++) {\n sourceCoords.push(`imod(${currentCoords[i]}, ${aShape[i]})`);\n }\n return sourceCoords.join();\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nexport const ERF_P = 0.3275911;\nexport const ERF_A1 = 0.254829592;\nexport const ERF_A2 = -0.284496736;\nexport const ERF_A3 = 1.421413741;\nexport const ERF_A4 = -1.453152027;\nexport const ERF_A5 = 1.061405429;\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nexport const SELU_SCALEALPHA = 1.7580993408473768599402175208123;\nexport const SELU_SCALE = 1.0507009873554804934193349852946;\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as erf_util from '../../ops/erf_util';\nimport * as selu_util from '../../ops/selu_util';\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class UnaryOpProgram implements GPGPUProgram {\n variableNames = ['A'];\n userCode: string;\n outputShape: number[];\n\n constructor(aShape: number[], opSnippet: string) {\n this.outputShape = aShape;\n this.userCode = `\n float unaryOperation(float x) {\n ${opSnippet}\n }\n\n void main() {\n float x = getAAtOutCoords();\n float y = unaryOperation(x);\n\n setOutput(y);\n }\n `;\n }\n}\n\nconst CHECK_NAN_SNIPPET = `if (isnan(x)) return x;`;\n\nexport const LINEAR = `return x;`;\n\nexport const ABS = `return abs(x);`;\n\nexport const RELU = CHECK_NAN_SNIPPET + `\n return (x < 0.0) ? 0.0 : x;\n`;\n\nexport const RELU6 = CHECK_NAN_SNIPPET + `\n return (x < 0.0) ? 0.0 : min(6.0, x);\n`;\n\nexport const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;\n\nexport const SELU = `\n // Stable and Attracting Fixed Point (0, 1) for Normalized Weights.\n // see: https://arxiv.org/abs/1706.02515\n float scaleAlpha = ${selu_util.SELU_SCALEALPHA};\n float scale = ${selu_util.SELU_SCALE};\n return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);\n`;\n\nexport function STEP(alpha = 0.0) {\n return CHECK_NAN_SNIPPET + `\n return x > 0.0 ? 1.0 : float(${alpha});\n `;\n}\n\nexport const NEG = `return -x;`;\n\nexport const CEIL = `return ceil(x);`;\n\nexport const FLOOR = `return floor(x);`;\n\nexport const SIGN = `\n if (isnan(x)) { return 0.0; }\n return sign(x);\n`;\n\nexport const IS_NAN = `return float(isnan(x));`;\n\nexport const IS_INF = `return float(isinf(x));`;\n\nexport const IS_FINITE = `return float(!isnan(x) && !isinf(x));`;\n\nexport const ROUND = `\n // OpenGL ES does not support round function.\n // The algorithm is based on banker's rounding.\n float base = floor(x);\n if ((x - base) < 0.5) {\n return floor(x);\n } else if ((x - base) > 0.5) {\n return ceil(x);\n } else {\n if (mod(base, 2.0) == 0.0) {\n return base;\n } else {\n return base + 1.0;\n }\n }\n`;\n\nexport const EXP = `return exp(x);`;\n\nexport const EXPM1 = `return exp(x) - 1.0;`;\n\nexport const LOG = `if (x < 0.0) return NAN;\n return log(x);`;\n\nexport const LOG1P = `return log(1.0 + x);`;\n\nexport const SQRT = `return sqrt(x);`;\n\nexport const RSQRT = `return inversesqrt(x);`;\n\nexport const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * x));`;\n\n/**\n * mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX\n *\n * epsilon is the difference between 1.0 and the next representable\n * float. For a single precision 32 bit float this should be 2^-23, see:\n * https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm\n *\n * too_large = (x > -threshold) is value above which exp(x) may overflow\n * but softplus(x) == x is within machine epsilon\n *\n * too_small = (x < threshold) is value below which exp(x) may underflow,\n * but softplus(x) == exp(x) is within machine epsilon.\n */\nexport const SOFTPLUS = `\n float epsilon = 1.1920928955078125e-7;\n float threshold = log(epsilon) + 2.0;\n\n bool too_large = x > -threshold;\n bool too_small = x < threshold;\n\n float result;\n float exp_x = exp(x);\n\n if (too_large){\n result = x;\n }\n else if (too_small){\n result = exp_x;\n }\n else{\n result = log(exp_x + 1.0);\n }\n return result;\n`;\n\nexport const SIN = CHECK_NAN_SNIPPET + `\n return sin(x);\n`;\n\nexport const COS = CHECK_NAN_SNIPPET + `\n return cos(x);\n`;\n\nexport const TAN = `return tan(x);`;\n\nexport const ASIN = CHECK_NAN_SNIPPET + `\n if (abs(x) > 1.) {\n return NAN;\n }\n return asin(x);\n`;\n\nexport const ACOS = CHECK_NAN_SNIPPET + `\n if (abs(x) > 1.) {\n return NAN;\n }\n return acos(x);\n`;\n\nexport const ATAN = CHECK_NAN_SNIPPET + `\n return atan(x);\n`;\n\nexport const SINH = `\n float e2x = exp(x);\n return (e2x - 1.0 / e2x) / 2.0;\n`;\n\nexport const COSH = `\n float e2x = exp(-x);\n return (e2x + 1.0 / e2x) / 2.0;\n`;\n\nexport const TANH = `\n float e2x = exp(-2.0 * abs(x));\n return sign(x) * (1.0 - e2x) / (1.0 + e2x);\n`;\n\nexport const ASINH = CHECK_NAN_SNIPPET + `return log(x + sqrt(x * x + 1.0));`;\n\nexport const ACOSH = CHECK_NAN_SNIPPET + `\n if (x < 1.0) return NAN;\n return log(x + sqrt(x * x - 1.0));`;\n\nexport const ATANH = CHECK_NAN_SNIPPET + `\n if ((x < -1.0) || (x > 1.0)) return NAN;\n return (log(1.0 + x) - log(1.0 - x)) / 2.0;`;\n\nexport const ERF = `\n // Error function is calculated approximately with elementary function.\n // See \"Handbook of Mathematical Functions with Formulas,\n // Graphs, and Mathematical Tables\", Abramowitz and Stegun.\n float p = ${erf_util.ERF_P};\n float a1 = ${erf_util.ERF_A1};\n float a2 = ${erf_util.ERF_A2};\n float a3 = ${erf_util.ERF_A3};\n float a4 = ${erf_util.ERF_A4};\n float a5 = ${erf_util.ERF_A5};\n\n float sign = sign(x);\n x = abs(x);\n float t = 1.0 / (1.0 + p * x);\n return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));\n`;\n\nexport const SQUARE = `return x * x;`;\n\nexport const RECIPROCAL = `return 1.0 / x;`;\n\nexport const LOGICAL_NOT = `return float(!(x >= 1.0));`;\n\nexport const TO_INT = `return float(int(x));`;\n\nexport const CLONE = 'return x;';\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport const LINEAR = `return x;`;\n\nexport const LOG = `\n vec4 result = log(x);\n vec4 isNaN = vec4(lessThan(x, vec4(0.0)));\n result.r = isNaN.r == 1.0 ? NAN : result.r;\n result.g = isNaN.g == 1.0 ? NAN : result.g;\n result.b = isNaN.b == 1.0 ? NAN : result.b;\n result.a = isNaN.a == 1.0 ? NAN : result.a;\n\n return result;\n`;\n\nexport const RELU = `\n vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n`;\n\nexport const RELU6 = `\n vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n`;\n\nexport const ELU = `\n vec4 result;\n\n result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);\n result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);\n result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);\n result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);\n\n return result;\n`;\n\nexport class UnaryOpPackedProgram implements GPGPUProgram {\n variableNames = ['A'];\n userCode: string;\n outputShape: number[];\n packedInputs = true;\n packedOutput = true;\n\n constructor(aShape: number[], opSnippet: string) {\n this.outputShape = aShape;\n this.userCode = `\n vec4 unaryOperation(vec4 x) {\n ${opSnippet}\n }\n\n void main() {\n vec4 x = getAAtOutCoords();\n vec4 y = unaryOperation(x);\n\n setOutput(y);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getChannels, getSourceCoords} from '../packing_util';\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class UnpackProgram implements GPGPUProgram {\n variableNames = ['A'];\n packedInputs = true;\n packedOutput = false;\n outputShape: number[];\n userCode: string;\n\n constructor(outputShape: number[]) {\n this.outputShape = outputShape;\n const rank = outputShape.length;\n\n const channels = getChannels('rc', rank);\n const dtype = getCoordsDataType(rank);\n const sourceCoords = getSourceCoords(rank, channels);\n const innerDims = channels.slice(-2);\n const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`;\n\n this.userCode = `\n void main() {\n ${dtype} rc = getOutputCoords();\n vec4 packedInput = getA(${sourceCoords});\n\n setOutput(getChannel(packedInput, ${coords}));\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n// Import webgl flags.\nimport './flags_webgl';\n\nimport * as device_util from '../../device_util';\nimport {ENGINE, MemoryInfo, TimingInfo} from '../../engine';\nimport {env} from '../../environment';\nimport {tidy} from '../../globals';\nimport {TensorInfo} from '../../kernel_registry';\nimport {warn} from '../../log';\nimport {buffer} from '../../ops/array_ops';\nimport * as array_ops_util from '../../ops/array_ops_util';\nimport * as axis_util from '../../ops/axis_util';\nimport {complex, imag, real} from '../../ops/complex_ops';\nimport {computeOutShape} from '../../ops/concat_util';\nimport {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';\nimport {div} from '../../ops/div';\nimport {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';\nimport * as gather_nd_util from '../../ops/gather_nd_util';\nimport * as reduce_util from '../../ops/reduce_util';\nimport * as scatter_nd_util from '../../ops/scatter_nd_util';\nimport * as segment_util from '../../ops/segment_util';\nimport * as slice_util from '../../ops/slice_util';\nimport {softmax} from '../../ops/softmax';\nimport {range, scalar, tensor} from '../../ops/tensor_ops';\nimport {transpose} from '../../ops/transpose';\nimport {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';\nimport {BackendValues, DataType, DataTypeMap, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types';\nimport * as util from '../../util';\nimport {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} from '../../util';\nimport {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend';\nimport * as backend_util from '../backend_util';\nimport {mergeRealAndImagArrays} from '../complex_util';\nimport {nonMaxSuppressionV3} from '../non_max_suppression_impl';\nimport {split} from '../split_shared';\nimport {tile} from '../tile_impl';\nimport {topkImpl} from '../topk_impl';\nimport {whereImpl} from '../where_impl';\n\nimport {AddNProgram} from './addn_gpu';\nimport {AddNPackedProgram} from './addn_packed_gpu';\nimport {ArgMinMaxProgram} from './argminmax_gpu';\nimport {ArgMinMaxPackedProgram} from './argminmax_packed_gpu';\nimport {AvgPool2DBackpropProgram, AvgPool3DBackpropProgram} from './avg_pool_backprop_gpu';\nimport {BatchNormProgram} from './batchnorm_gpu';\nimport {BatchNormPackedProgram} from './batchnorm_packed_gpu';\nimport * as binaryop_complex_gpu from './binaryop_complex_gpu';\nimport {BinaryOpComplexProgram} from './binaryop_complex_gpu';\nimport * as binaryop_gpu from './binaryop_gpu';\nimport {BinaryOpProgram} from './binaryop_gpu';\nimport * as binaryop_packed_gpu from './binaryop_packed_gpu';\nimport {BinaryOpPackedProgram} from './binaryop_packed_gpu';\nimport {getWebGLContext} from './canvas_util';\nimport {ClipProgram} from './clip_gpu';\nimport {ClipPackedProgram} from './clip_packed_gpu';\nimport {ComplexAbsProgram} from './complex_abs_gpu';\nimport {ConcatProgram} from './concat_gpu';\nimport {ConcatPackedProgram} from './concat_packed_gpu';\nimport {Conv2DDerFilterProgram, Conv2DDerInputProgram, Conv3DDerFilterProgram, Conv3DDerInputProgram} from './conv_backprop_gpu';\nimport {DepthwiseConv2DDerFilterProgram, DepthwiseConv2DDerInputProgram} from './conv_backprop_gpu_depthwise';\nimport {Conv2DProgram, Conv3DProgram} from './conv_gpu';\nimport {DepthwiseConv2DProgram} from './conv_gpu_depthwise';\nimport {DepthwiseConvPacked2DProgram} from './conv_packed_gpu_depthwise';\nimport {CropAndResizeProgram} from './crop_and_resize_gpu';\nimport {CumSumProgram} from './cumsum_gpu';\nimport {DecodeMatrixProgram} from './decode_matrix_gpu';\nimport {DecodeMatrixPackedProgram} from './decode_matrix_packed_gpu';\nimport {DepthToSpaceProgram} from './depth_to_space_gpu';\nimport {DiagProgram} from './diag_gpu';\nimport {EncodeFloatProgram} from './encode_float_gpu';\nimport {EncodeFloatPackedProgram} from './encode_float_packed_gpu';\nimport {EncodeMatrixProgram} from './encode_matrix_gpu';\nimport {EncodeMatrixPackedProgram} from './encode_matrix_packed_gpu';\nimport * as fft_gpu from './fft_gpu';\nimport {FFTProgram} from './fft_gpu';\nimport {FillProgram} from './fill_gpu';\nimport {GatherProgram} from './gather_gpu';\nimport {GatherNDProgram} from './gather_nd_gpu';\nimport {GPGPUContext} from './gpgpu_context';\nimport * as gpgpu_math from './gpgpu_math';\nimport {GPGPUBinary, GPGPUProgram, TensorData} from './gpgpu_math';\nimport {Im2ColPackedProgram} from './im2col_packed_gpu';\nimport {LRNProgram} from './lrn_gpu';\nimport {LRNGradProgram} from './lrn_grad_gpu';\nimport {LRNPackedProgram} from './lrn_packed_gpu';\nimport {MaxPool2DBackpropProgram, MaxPool3DBackpropProgram} from './max_pool_backprop_gpu';\nimport {MatMulPackedProgram} from './mulmat_packed_gpu';\nimport {MultinomialProgram} from './multinomial_gpu';\nimport {OneHotProgram} from './onehot_gpu';\nimport {PackProgram} from './pack_gpu';\nimport {PadProgram} from './pad_gpu';\nimport {PadPackedProgram} from './pad_packed_gpu';\nimport {Pool2DProgram, Pool3DProgram} from './pool_gpu';\nimport {ReduceProgram} from './reduce_gpu';\nimport {ReshapePackedProgram} from './reshape_packed_gpu';\nimport {ResizeBilinearBackpropProgram} from './resize_bilinear_backprop_gpu';\nimport {ResizeBilinearProgram} from './resize_bilinear_gpu';\nimport {ResizeBilinearPackedProgram} from './resize_bilinear_packed_gpu';\nimport {ResizeNearestNeigborBackpropProgram} from './resize_nearest_neighbor_backprop_gpu';\nimport {ResizeNearestNeighborProgram} from './resize_nearest_neighbor_gpu';\nimport {ReverseProgram} from './reverse_gpu';\nimport {ReversePackedProgram} from './reverse_packed_gpu';\nimport {ScatterProgram} from './scatter_gpu';\nimport {SegmentOpProgram} from './segment_gpu';\nimport {SelectProgram} from './select_gpu';\nimport {SliceProgram} from './slice_gpu';\nimport {SlicePackedProgram} from './slice_packed_gpu';\nimport {StridedSliceProgram} from './strided_slice_gpu';\nimport * as tex_util from './tex_util';\nimport {TextureData, TextureUsage} from './tex_util';\nimport {TextureManager} from './texture_manager';\nimport {TileProgram} from './tile_gpu';\nimport * as unary_op from './unaryop_gpu';\nimport {UnaryOpProgram} from './unaryop_gpu';\nimport * as unary_packed_op from './unaryop_packed_gpu';\nimport {UnaryOpPackedProgram} from './unaryop_packed_gpu';\nimport {UnpackProgram} from './unpack_gpu';\nimport * as webgl_util from './webgl_util';\n\ntype KernelInfo = {\n name: string; query: Promise;\n};\n\nexport type TimerNode = RecursiveArray|KernelInfo;\nexport interface CPUTimerQuery {\n startMs: number;\n endMs?: number;\n}\n\nexport interface WebGLMemoryInfo extends MemoryInfo {\n numBytesInGPU: number;\n unreliable: boolean;\n}\n\nexport interface WebGLTimingInfo extends TimingInfo {\n uploadWaitMs: number;\n downloadWaitMs: number;\n}\n\nconst binaryCaches: {[webGLVersion: string]: {[key: string]: GPGPUBinary}} = {};\n\nexport function getBinaryCache(webGLVersion: number) {\n if (webGLVersion in binaryCaches) {\n return binaryCaches[webGLVersion];\n }\n binaryCaches[webGLVersion] = {};\n return binaryCaches[webGLVersion];\n}\n\nfunction mapActivationToShaderProgram(\n activation: Activation, packed = false): string {\n if (activation === 'linear') {\n if (packed) {\n return unary_packed_op.LINEAR;\n }\n return unary_op.LINEAR;\n } else if (activation === 'relu') {\n if (packed) {\n return unary_packed_op.RELU;\n }\n return unary_op.RELU;\n } else if (activation === 'elu') {\n if (packed) {\n return unary_packed_op.ELU;\n }\n return unary_op.ELU;\n } else if (activation === 'relu6') {\n if (packed) {\n return unary_packed_op.RELU6;\n }\n return unary_op.RELU6;\n } else if (activation === 'prelu') {\n if (packed) {\n return binaryop_packed_gpu.PRELU;\n }\n return binaryop_gpu.PRELU;\n }\n throw new Error(`Activation ${\n activation} has not been implemented for the WebGL backend.`);\n}\n\n// Empirically determined constant used to determine size threshold for handing\n// off execution to the CPU.\nconst CPU_HANDOFF_SIZE_THRESHOLD = 128;\n\n// Empirically determined constant used to decide the number of MB on GPU\n// before we warn about high memory use. The MB are this constant * screen area\n// * dpi / 1024 / 1024.\nconst BEFORE_PAGING_CONSTANT = 600;\nfunction numMBBeforeWarning(): number {\n if (env().global.screen == null) {\n return 1024; // 1 GB.\n }\n return (env().global.screen.height * env().global.screen.width *\n window.devicePixelRatio) *\n BEFORE_PAGING_CONSTANT / 1024 / 1024;\n}\n\n// Empirically determined minimal shared dimension in matmul before we forward\n// to a.mul(b).sum() in order to take advantage of GPU parallelism. See\n// https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks.\nexport const MATMUL_SHARED_DIM_THRESHOLD = 1000;\n\nexport class MathBackendWebGL extends KernelBackend {\n texData: DataStorage;\n gpgpu: GPGPUContext;\n\n // Maps data ids that have a pending read operation, to list of subscribers.\n private pendingRead = new WeakMap void>>();\n // List of data ids that are scheduled for disposal, but are waiting on a\n // pending read operation.\n private pendingDisposal = new WeakSet();\n // Used to count the number of 'shallow' sliced tensors that point to the\n // same data id.\n private dataRefCount = new WeakMap();\n private numBytesInGPU = 0;\n\n private canvas: HTMLCanvasElement|OffscreenCanvas;\n\n private programTimersStack: TimerNode[];\n private activeTimers: TimerNode[];\n // Accumulated time spent (including blocking) in uploading data to webgl.\n private uploadWaitMs = 0;\n // Accumulated time spent (including blocking in downloading data from webgl.\n private downloadWaitMs = 0;\n private cpuBackend: KernelBackend;\n\n // Number of bits of precision of this backend.\n private floatPrecisionValue: 32|16;\n\n private textureManager: TextureManager;\n private binaryCache: {[key: string]: GPGPUBinary};\n private gpgpuCreatedLocally: boolean;\n private numMBBeforeWarning: number;\n private warnedAboutMemory = false;\n\n constructor(gpgpu?: GPGPUContext) {\n super();\n if (!env().getBool('HAS_WEBGL')) {\n throw new Error('WebGL is not supported on this device');\n }\n\n if (gpgpu == null) {\n const gl = getWebGLContext(env().getNumber('WEBGL_VERSION'));\n this.binaryCache = getBinaryCache(env().getNumber('WEBGL_VERSION'));\n this.gpgpu = new GPGPUContext(gl);\n this.canvas = gl.canvas;\n this.gpgpuCreatedLocally = true;\n } else {\n this.gpgpu = gpgpu;\n this.binaryCache = {};\n this.gpgpuCreatedLocally = false;\n this.canvas = gpgpu.gl.canvas;\n }\n this.textureManager = new TextureManager(this.gpgpu);\n this.numMBBeforeWarning = numMBBeforeWarning();\n\n this.texData = new DataStorage(this, ENGINE);\n }\n\n numDataIds() {\n return this.texData.numDataIds() +\n (this.cpuBackend ? this.cpuBackend.numDataIds() : 0) -\n this.pendingDeletes;\n }\n\n write(values: BackendValues, shape: number[], dtype: DataType): DataId {\n if (env().getBool('DEBUG')) {\n this.checkNumericalProblems(values);\n }\n if (dtype === 'complex64' && values != null) {\n throw new Error(\n `Cannot write to a complex64 dtype. ` +\n `Please use tf.complex(real, imag).`);\n }\n const dataId = {};\n this.texData.set(\n dataId, {shape, dtype, values, usage: TextureUsage.UPLOAD});\n return dataId;\n }\n\n move(dataId: DataId, values: BackendValues, shape: number[], dtype: DataType):\n void {\n if (env().getBool('DEBUG')) {\n this.checkNumericalProblems(values);\n }\n if (dtype === 'complex64') {\n throw new Error(\n `Cannot write to a complex64 dtype. ` +\n `Please use tf.complex(real, imag).`);\n }\n this.texData.set(\n dataId, {shape, dtype, values, usage: TextureUsage.UPLOAD});\n }\n\n readSync(dataId: DataId): BackendValues {\n const texData = this.texData.get(dataId);\n const {values, dtype, complexTensors, slice, shape, isPacked} = texData;\n if (slice != null) {\n let program;\n if (isPacked) {\n program = new UnaryOpPackedProgram(shape, unary_op.CLONE);\n } else {\n program = new UnaryOpProgram(shape, unary_op.CLONE);\n }\n const res =\n this.runWebGLProgram(program, [{dataId, shape, dtype}], dtype);\n const data = this.readSync(res.dataId);\n this.disposeData(res.dataId);\n return data;\n }\n if (values != null) {\n return this.convertAndCacheOnCPU(dataId);\n }\n if (dtype === 'string') {\n return values;\n }\n const shouldTimeProgram = this.activeTimers != null;\n let start: number;\n if (shouldTimeProgram) {\n start = util.now();\n }\n\n let result: Float32Array;\n if (dtype === 'complex64') {\n const realValues = complexTensors.real.dataSync() as Float32Array;\n const imagValues = complexTensors.imag.dataSync() as Float32Array;\n result = mergeRealAndImagArrays(realValues, imagValues);\n } else {\n result = this.getValuesFromTexture(dataId);\n }\n\n if (shouldTimeProgram) {\n this.downloadWaitMs += util.now() - start;\n }\n return this.convertAndCacheOnCPU(dataId, result);\n }\n\n async read(dataId: DataId): Promise {\n if (this.pendingRead.has(dataId)) {\n const subscribers = this.pendingRead.get(dataId);\n return new Promise(resolve => subscribers.push(resolve));\n }\n const texData = this.texData.get(dataId);\n const {values, shape, slice, dtype, complexTensors, isPacked} = texData;\n\n if (slice != null) {\n let program;\n if (isPacked) {\n program = new UnaryOpPackedProgram(shape, unary_op.CLONE);\n } else {\n program = new UnaryOpProgram(shape, unary_op.CLONE);\n }\n const res =\n this.runWebGLProgram(program, [{dataId, shape, dtype}], dtype);\n const data = this.read(res.dataId);\n this.disposeData(res.dataId);\n return data;\n }\n\n if (values != null) {\n return this.convertAndCacheOnCPU(dataId);\n }\n\n if (!env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') &&\n env().getNumber('WEBGL_VERSION') === 2) {\n throw new Error(\n `tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and ` +\n `WEBGL_VERSION=2 not yet supported.`);\n }\n\n let buffer = null;\n let tmpDownloadTarget: TensorInfo;\n\n if (dtype !== 'complex64' && env().get('WEBGL_BUFFER_SUPPORTED')) {\n // Possibly copy the texture into a buffer before inserting a fence.\n tmpDownloadTarget = this.decode(dataId);\n const tmpData = this.texData.get(tmpDownloadTarget.dataId);\n\n buffer = this.gpgpu.createBufferFromTexture(\n tmpData.texture, ...tex_util.getDenseTexShape(shape));\n }\n\n this.pendingRead.set(dataId, []);\n\n if (dtype !== 'complex64') {\n // Create a fence and wait for it to resolve.\n await this.gpgpu.createAndWaitForFence();\n }\n\n // Download the values from the GPU.\n let vals: Float32Array;\n if (dtype === 'complex64') {\n const ps = await Promise.all(\n [complexTensors.real.data(), complexTensors.imag.data()]);\n const realValues = ps[0];\n const imagValues = ps[1];\n vals = mergeRealAndImagArrays(\n realValues as Float32Array, imagValues as Float32Array);\n } else if (buffer == null) {\n vals = this.getValuesFromTexture(dataId);\n } else {\n const size = util.sizeFromShape(shape);\n vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);\n }\n if (tmpDownloadTarget != null) {\n this.disposeData(tmpDownloadTarget.dataId);\n }\n const dTypeVals = this.convertAndCacheOnCPU(dataId, vals);\n\n const subscribers = this.pendingRead.get(dataId);\n this.pendingRead.delete(dataId);\n\n // Notify all pending reads.\n subscribers.forEach(resolve => resolve(dTypeVals));\n if (this.pendingDisposal.has(dataId)) {\n this.pendingDisposal.delete(dataId);\n this.disposeData(dataId);\n this.pendingDeletes--;\n }\n return dTypeVals;\n }\n\n private checkNumericalProblems(values: BackendValues): void {\n if (values == null) {\n return;\n }\n for (let i = 0; i < values.length; i++) {\n const num = values[i] as number;\n if (!webgl_util.canBeRepresented(num)) {\n if (env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) {\n throw Error(\n `The value ${num} cannot be represented with your ` +\n `current settings. Consider enabling float32 rendering: ` +\n `'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`);\n }\n throw Error(`The value ${num} cannot be represented on this device.`);\n }\n }\n }\n\n private getValuesFromTexture(dataId: DataId): Float32Array {\n const {shape, dtype, isPacked} = this.texData.get(dataId);\n const size = util.sizeFromShape(shape);\n if (env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {\n const tmpTarget = this.decode(dataId);\n const tmpData = this.texData.get(tmpTarget.dataId);\n const vals = this.gpgpu\n .downloadMatrixFromPackedTexture(\n tmpData.texture, ...tex_util.getDenseTexShape(shape))\n .subarray(0, size);\n\n this.disposeData(tmpTarget.dataId);\n\n return vals;\n }\n\n const shouldUsePackedProgram =\n env().getBool('WEBGL_PACK') && isPacked === true;\n const outputShape =\n shouldUsePackedProgram ? webgl_util.getShapeAs3D(shape) : shape;\n const program = shouldUsePackedProgram ?\n new EncodeFloatPackedProgram(outputShape as [number, number, number]) :\n new EncodeFloatProgram(outputShape);\n const output = this.runWebGLProgram(\n program, [{shape: outputShape, dtype, dataId}], 'float32');\n const tmpData = this.texData.get(output.dataId);\n const vals =\n this.gpgpu\n .downloadByteEncodedFloatMatrixFromOutputTexture(\n tmpData.texture, tmpData.texShape[0], tmpData.texShape[1])\n .subarray(0, size);\n this.disposeData(output.dataId);\n\n return vals;\n }\n\n async time(f: () => void): Promise {\n const oldActiveTimers = this.activeTimers;\n const newActiveTimers: TimerNode[] = [];\n\n let outerMostTime = false;\n if (this.programTimersStack == null) {\n this.programTimersStack = newActiveTimers;\n outerMostTime = true;\n } else {\n this.activeTimers.push(newActiveTimers);\n }\n this.activeTimers = newActiveTimers;\n\n f();\n\n // needing to split these up because util.flatten only accepts certain types\n const flattenedActiveTimerQueries =\n util.flatten(this.activeTimers.map((d: KernelInfo) => d.query))\n .filter(d => d != null);\n const flattenedActiveTimerNames =\n util.flatten(this.activeTimers.map((d: KernelInfo) => d.name))\n .filter(d => d != null);\n\n this.activeTimers = oldActiveTimers;\n\n if (outerMostTime) {\n this.programTimersStack = null;\n }\n\n const res: WebGLTimingInfo = {\n uploadWaitMs: this.uploadWaitMs,\n downloadWaitMs: this.downloadWaitMs,\n kernelMs: null,\n wallMs: null // will be filled by the engine\n };\n\n if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {\n const kernelMs = await Promise.all(flattenedActiveTimerQueries);\n\n res['kernelMs'] = util.sum(kernelMs);\n res['getExtraProfileInfo'] = () =>\n kernelMs.map((d, i) => ({name: flattenedActiveTimerNames[i], ms: d}))\n .map(d => `${d.name}: ${d.ms}`)\n .join(', ');\n } else {\n res['kernelMs'] = {\n error: 'WebGL query timers are not supported in this environment.'\n };\n }\n\n this.uploadWaitMs = 0;\n this.downloadWaitMs = 0;\n return res;\n }\n memory(): WebGLMemoryInfo {\n return {unreliable: false, numBytesInGPU: this.numBytesInGPU} as\n WebGLMemoryInfo;\n }\n\n private startTimer(): WebGLQuery|CPUTimerQuery {\n if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {\n return this.gpgpu.beginQuery();\n }\n return {startMs: util.now(), endMs: null};\n }\n\n private endTimer(query: WebGLQuery|CPUTimerQuery): WebGLQuery|CPUTimerQuery {\n if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {\n this.gpgpu.endQuery();\n return query;\n }\n (query as CPUTimerQuery).endMs = util.now();\n return query;\n }\n\n private async getQueryTime(query: WebGLQuery|CPUTimerQuery): Promise {\n if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {\n return this.gpgpu.waitForQueryAndGetTime(query as WebGLQuery);\n }\n const timerQuery = query as CPUTimerQuery;\n return timerQuery.endMs - timerQuery.startMs;\n }\n\n private pendingDeletes = 0;\n\n disposeData(dataId: DataId): void {\n if (this.pendingDisposal.has(dataId)) {\n return;\n }\n if (this.pendingRead.has(dataId)) {\n this.pendingDisposal.add(dataId);\n this.pendingDeletes++;\n return;\n }\n // No-op if already disposed.\n if (!this.texData.has(dataId)) {\n return;\n }\n\n this.releaseGPUData(dataId);\n const {complexTensors} = this.texData.get(dataId);\n if (complexTensors != null) {\n complexTensors.real.dispose();\n complexTensors.imag.dispose();\n }\n this.texData.delete(dataId);\n }\n\n private releaseGPUData(dataId: DataId): void {\n const {texture, dtype, texShape, usage, isPacked, slice} =\n this.texData.get(dataId);\n const key = slice && slice.origDataId || dataId;\n const refCount = this.dataRefCount.get(key);\n if (refCount > 1) {\n this.dataRefCount.set(key, refCount - 1);\n } else {\n this.dataRefCount.delete(key);\n if (texture != null) {\n this.numBytesInGPU -= this.computeBytes(texShape, dtype);\n this.textureManager.releaseTexture(texture, texShape, usage, isPacked);\n }\n }\n const texData = this.texData.get(dataId);\n texData.texture = null;\n texData.texShape = null;\n texData.isPacked = false;\n texData.slice = null;\n }\n\n getTexture(dataId: DataId): WebGLTexture {\n this.uploadToGPU(dataId);\n return this.texData.get(dataId).texture;\n }\n\n /**\n * Returns internal information for the specific data bucket. Used in unit\n * tests.\n */\n getDataInfo(dataId: DataId): TextureData {\n return this.texData.get(dataId);\n }\n\n private getCPUBackend(): KernelBackend|null {\n if (!env().getBool('WEBGL_CPU_FORWARD')) {\n return null;\n }\n\n if (this.cpuBackend == null) {\n this.cpuBackend = ENGINE.findBackend('cpu');\n }\n\n return this.cpuBackend;\n }\n\n /*\n Tests whether all the inputs to an op are small and on the CPU. This heuristic\n determines when it would be faster to execute a kernel on the CPU. WebGL\n kernels opt into running this check and forwarding when appropriate.\n TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more\n sustainable strategy for optimizing backend execution of ops.\n */\n shouldExecuteOnCPU(\n inputs: TensorInfo[],\n sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD): boolean {\n return this.getCPUBackend() != null &&\n inputs.every(\n input => this.texData.get(input.dataId).texture == null &&\n util.sizeFromShape(input.shape) < sizeThreshold);\n }\n\n getGPGPUContext(): GPGPUContext {\n return this.gpgpu;\n }\n\n complex(real: T, imag: T): T {\n const result = this.makeOutput(real.shape, 'complex64');\n const resultData = this.texData.get(result.dataId);\n // The backend owns the reference to the underlying real and imaginary\n // clones. These will explicitly get disposed when the complex tensor is\n // disposed.\n resultData.complexTensors = {\n real: ENGINE.keep(real.clone()),\n imag: ENGINE.keep(imag.clone())\n };\n\n return result as T;\n }\n real(input: T): T {\n const resultData = this.texData.get(input.dataId);\n return resultData.complexTensors.real.clone() as T;\n }\n imag(input: T): T {\n const resultData = this.texData.get(input.dataId);\n return resultData.complexTensors.imag.clone() as T;\n }\n\n slice(x: T, begin: number[], size: number[]): T {\n if (this.shouldExecuteOnCPU([x])) {\n return this.cpuBackend.slice(x, begin, size);\n }\n // Short-circuit computation if the slice is zero-sized.\n if (util.sizeFromShape(size) === 0) {\n return tensor([], size, x.dtype) as T;\n }\n const {isPacked} = this.texData.get(x.dataId);\n const isContinous = slice_util.isSliceContinous(x.shape, begin, size);\n if (isPacked || !isContinous) {\n const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?\n new SlicePackedProgram(size) :\n new SliceProgram(size);\n const customSetup = program.getCustomSetupFunc(begin);\n return this.compileAndRun(program, [x], null, customSetup);\n }\n this.uploadToGPU(x.dataId);\n return this.shallowSlice(x, begin, size) as T;\n }\n\n private shallowSlice(x: Tensor, begin: number[], size: number[]): Tensor {\n const xTexData = this.texData.get(x.dataId);\n const t = this.makeOutput(size, x.dtype);\n const newTexData = this.texData.get(t.dataId);\n // Copy texture data from the original tensor.\n Object.assign(newTexData, xTexData);\n newTexData.shape = size;\n newTexData.dtype = x.dtype;\n let flatOffset = slice_util.computeFlatOffset(begin, x.strides);\n if (xTexData.slice) {\n // We are slicing an already sliced tensor, so we have to accumulate\n // the offset.\n flatOffset += xTexData.slice.flatOffset;\n }\n newTexData.slice = {\n flatOffset,\n // Point to the original dataId, which is used to do ref counting.\n origDataId: xTexData.slice && xTexData.slice.origDataId || x.dataId\n };\n\n // Increase the ref count for that data bucket.\n const refCount = this.dataRefCount.get(newTexData.slice.origDataId) || 1;\n this.dataRefCount.set(newTexData.slice.origDataId, refCount + 1);\n\n return t;\n }\n\n stridedSlice(\n x: T, begin: number[], end: number[], strides: number[]): T {\n if (this.shouldExecuteOnCPU([x])) {\n return this.cpuBackend.stridedSlice(x, begin, end, strides);\n }\n\n const outShape = slice_util.computeOutShape(begin, end, strides);\n\n if (outShape.some(axis => axis === 0)) {\n return tensor([], outShape) as T;\n }\n\n const program = new StridedSliceProgram(begin, strides, outShape);\n return this.compileAndRun(program, [x]);\n }\n\n reverse(x: T, axis: number[]): T {\n const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?\n new ReversePackedProgram(x.shape, axis) :\n new ReverseProgram(x.shape, axis);\n return this.compileAndRun(program, [x]);\n }\n\n concat(tensors: Tensor[], axis: number): Tensor {\n if (tensors[0].dtype === 'complex64') {\n const reals = tensors.map((t) => real(t));\n const imags = tensors.map((t) => imag(t));\n return complex(this.concat(reals, axis), this.concat(imags, axis));\n }\n if (this.shouldExecuteOnCPU(tensors)) {\n return this.cpuBackend.concat(tensors, axis);\n }\n\n if (tensors.length === 1) {\n return tensors[0];\n }\n if (tensors.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) {\n const midIndex = Math.floor(tensors.length / 2);\n const leftSide = this.concat(tensors.slice(0, midIndex), axis);\n const rightSide = this.concat(tensors.slice(midIndex), axis);\n return this.concat([leftSide, rightSide], axis);\n }\n if (env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') && tensors[0].rank > 1) {\n const program = new ConcatPackedProgram(tensors.map(t => t.shape), axis);\n return this.compileAndRun(program, tensors);\n }\n // Any concat of n-dimensional tensors across any axis can be reduced to\n // a concatenation of two-dimensional tensors across the axis 1 by first\n // partitioning the axes of the original tensors into those less than the\n // axis to be concatenated and the rest. Then reshape the tensors\n // into a two-dimensional tensor by collapsing these two sets of axes and\n // concatenate the resulting matrices across the axis 1, finally reshaping\n // the result to have the proper shape.\n const outShape = computeOutShape(tensors.map(t => t.shape), axis);\n const tensors2D =\n tensors.map(t => t.as2D(-1, sizeFromShape(t.shape.slice(axis))));\n const program = new ConcatProgram(tensors2D.map(t => t.shape));\n const res: Tensor = this.compileAndRun(program, tensors2D);\n return res.reshape(outShape);\n }\n\n neg(x: T): T {\n if (this.shouldExecuteOnCPU([x])) {\n return this.cpuBackend.neg(x);\n }\n\n if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {\n return this.packedUnaryOp(x, unary_op.NEG, x.dtype) as T;\n }\n const program = new UnaryOpProgram(x.shape, unary_op.NEG);\n return this.compileAndRun(program, [x]);\n }\n\n batchMatMul(\n a: Tensor3D, b: Tensor3D, transposeA: boolean,\n transposeB: boolean): Tensor3D {\n const outerShapeA = transposeA ? a.shape[2] : a.shape[1];\n const outerShapeB = transposeB ? b.shape[1] : b.shape[2];\n const sharedDim = transposeA ? a.shape[1] : a.shape[2];\n const [batch, , ] = a.shape;\n\n // Since the matrices are vectors, it is faster to call mul().sum()\n // because sum() is O(sqrt(N)) due to divide-and-conquer.\n if ((outerShapeA === 1 || outerShapeB === 1) &&\n sharedDim > MATMUL_SHARED_DIM_THRESHOLD) {\n if (transposeA) {\n a = transpose(a, [0, 2, 1]);\n }\n if (transposeB) {\n b = transpose(b, [0, 2, 1]);\n }\n\n const a3D = outerShapeB === 1 ? a : a.as3D(batch, sharedDim, 1);\n const axis = outerShapeB === 1 ? 2 : 1;\n const b3D = outerShapeB === 1 ? b.as3D(batch, 1, sharedDim) : b;\n return this.multiply(a3D, b3D).sum(axis, true /* keepDims */);\n }\n\n const dtype = upcastType(a.dtype, b.dtype);\n\n const program = new MatMulPackedProgram(\n a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB);\n return this.compileAndRun(program, [a, b], dtype);\n }\n\n fusedBatchMatMul(\n {a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:\n FusedBatchMatMulConfig): Tensor3D {\n const outerShapeA = transposeA ? a.shape[2] : a.shape[1];\n const outerShapeB = transposeB ? b.shape[1] : b.shape[2];\n const [batch, , ] = a.shape;\n\n const dtype = upcastType(a.dtype, b.dtype);\n\n const hasBias = bias != null;\n const hasPreluActivationWeights = preluActivationWeights != null;\n const fusedActivation =\n activation ? mapActivationToShaderProgram(activation, true) : null;\n const program = new MatMulPackedProgram(\n a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB,\n hasBias, fusedActivation, hasPreluActivationWeights);\n const inputs: TensorInfo[] = [a, b];\n if (bias) {\n inputs.push(bias);\n }\n if (preluActivationWeights) {\n inputs.push(preluActivationWeights);\n }\n return this.compileAndRun(program, inputs, dtype);\n }\n\n multiply(a: Tensor, b: Tensor): Tensor {\n if (a.dtype === 'complex64') {\n const aData = this.texData.get(a.dataId);\n const bData = this.texData.get(b.dataId);\n\n const realProgram = new BinaryOpComplexProgram(\n binaryop_complex_gpu.COMPLEX_MULTIPLY.REAL, a.shape, b.shape);\n const imagProgram = new BinaryOpComplexProgram(\n binaryop_complex_gpu.COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);\n\n const inputs = [\n this.makeComplexComponentTensorInfo(a, aData.complexTensors.real),\n this.makeComplexComponentTensorInfo(a, aData.complexTensors.imag),\n this.makeComplexComponentTensorInfo(b, bData.complexTensors.real),\n this.makeComplexComponentTensorInfo(b, bData.complexTensors.imag)\n ];\n const real = this.compileAndRun(realProgram, inputs);\n const imag = this.compileAndRun(imagProgram, inputs);\n\n const complex = this.complex(real, imag);\n real.dispose();\n imag.dispose();\n return complex;\n }\n\n if (this.shouldExecuteOnCPU([a, b])) {\n return this.cpuBackend.multiply(a, b);\n }\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n return this.packedBinaryOp(a, b, binaryop_gpu.MUL, a.dtype);\n }\n const program = new BinaryOpProgram(binaryop_gpu.MUL, a.shape, b.shape);\n return this.compileAndRun(program, [a, b], a.dtype);\n }\n\n batchNormalization(\n x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D,\n varianceEpsilon: number, scale?: Tensor4D|Tensor1D,\n offset?: Tensor4D|Tensor1D): Tensor4D {\n const inputs = [x, mean, variance];\n\n let offsetShape = null;\n if (offset != null) {\n offsetShape = offset.shape;\n inputs.push(offset);\n }\n\n let scaleShape = null;\n if (scale != null) {\n scaleShape = scale.shape;\n inputs.push(scale);\n }\n\n if (env().getBool('WEBGL_PACK_NORMALIZATION')) {\n const batchNormPackedProgram = new BatchNormPackedProgram(\n x.shape, mean.shape, variance.shape, offsetShape, scaleShape,\n varianceEpsilon);\n return this.compileAndRun(batchNormPackedProgram, inputs);\n }\n\n const batchNormProgram = new BatchNormProgram(\n x.shape, mean.shape, variance.shape, offsetShape, scaleShape,\n varianceEpsilon);\n return this.compileAndRun(batchNormProgram, inputs);\n }\n\n localResponseNormalization4D(\n x: Tensor4D, radius: number, bias: number, alpha: number,\n beta: number): Tensor4D {\n const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?\n new LRNPackedProgram(x.shape, radius, bias, alpha, beta) :\n new LRNProgram(x.shape, radius, bias, alpha, beta);\n return this.compileAndRun(program, [x]);\n }\n\n LRNGrad(\n dy: Tensor4D, inputImage: Tensor4D, outputImage: Tensor4D,\n depthRadius: number, bias: number, alpha: number,\n beta: number): Tensor4D {\n const program =\n new LRNGradProgram(inputImage.shape, depthRadius, bias, alpha, beta);\n return this.compileAndRun(program, [inputImage, outputImage, dy]);\n }\n\n tile(x: T, reps: number[]): T {\n if (x.dtype === 'string') {\n const data = this.readSync(x.dataId) as Uint8Array[];\n const decodedData = data.map(d => util.decodeString(d));\n const buf = buffer(x.shape, x.dtype, decodedData);\n return tile(buf, reps) as T;\n }\n const program = new TileProgram(x.shape, reps);\n return this.compileAndRun(program, [x]);\n }\n\n pad(\n x: T, paddings: Array<[number, number]>, constantValue: number): T {\n const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?\n new PadPackedProgram(x.shape, paddings, constantValue) :\n new PadProgram(x.shape, paddings, constantValue);\n return this.compileAndRun(program, [x]);\n }\n\n gather(x: T, indices: Tensor1D, axis: number): T {\n if (this.shouldExecuteOnCPU([x, indices])) {\n return this.cpuBackend.gather(x, indices, axis);\n }\n const program = new GatherProgram(x.shape, indices.size, axis);\n return this.compileAndRun(program, [x, indices]);\n }\n\n batchToSpaceND(\n x: T, blockShape: number[], crops: number[][]): T {\n util.assert(\n x.rank <= 4,\n () => 'batchToSpaceND for rank > 4 with a WebGL backend not ' +\n 'implemented yet');\n const prod = blockShape.reduce((a, b) => a * b);\n\n const reshaped = array_ops_util.getReshaped(x.shape, blockShape, prod);\n const permuted =\n array_ops_util.getPermuted(reshaped.length, blockShape.length);\n const reshapedPermuted =\n array_ops_util.getReshapedPermuted(x.shape, blockShape, prod);\n const sliceBeginCoords =\n array_ops_util.getSliceBeginCoords(crops, blockShape.length);\n const sliceSize =\n array_ops_util.getSliceSize(reshapedPermuted, crops, blockShape.length);\n\n return transpose(x.reshape(reshaped), permuted)\n .reshape(reshapedPermuted)\n .slice(sliceBeginCoords, sliceSize) as T;\n }\n\n spaceToBatchND(\n x: T, blockShape: number[], paddings: Array<[number, number]>): T {\n util.assert(\n x.rank <= 4,\n () => 'spaceToBatchND for rank > 4 with a WebGL backend not ' +\n 'implemented yet');\n\n const prod = blockShape.reduce((a, b) => a * b);\n\n const completePaddings: Array<[number, number]> = [[0, 0]];\n completePaddings.push(...paddings);\n for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {\n completePaddings.push([0, 0]);\n }\n\n const paddedX = x.pad(completePaddings);\n\n const reshapedPaddedShape =\n array_ops_util.getReshaped(paddedX.shape, blockShape, prod, false);\n\n const permutedReshapedPaddedPermutation = array_ops_util.getPermuted(\n reshapedPaddedShape.length, blockShape.length, false);\n\n const flattenShape = array_ops_util.getReshapedPermuted(\n paddedX.shape, blockShape, prod, false);\n\n return transpose(\n paddedX.reshape(reshapedPaddedShape),\n permutedReshapedPaddedPermutation)\n .reshape(flattenShape) as T;\n }\n\n private reduce(\n x: Tensor2D, reduceType: 'all'|'any'|'max'|'min'|'sum'|'prod',\n dtype: DataType): Tensor2D {\n const batchSize = x.shape[0];\n const inSize = x.shape[1];\n const windowSize = reduce_util.computeOptimalWindowSize(inSize);\n const reduceInfo = {windowSize, inSize, batchSize};\n const program = new ReduceProgram(reduceInfo, reduceType);\n const output = this.compileAndRun(program, [x], dtype);\n // No need to run another GPGPU program.\n if (output.shape[1] === 1) {\n return output;\n }\n return this.reduce(output, reduceType, dtype);\n }\n\n private argReduce(\n x: Tensor2D, reduceType: 'max'|'min',\n bestIndicesA: Tensor2D = null): Tensor2D {\n let batchSize = x.shape[0];\n let inSize = x.shape[1];\n if (bestIndicesA != null) {\n batchSize = bestIndicesA.shape[0];\n inSize = bestIndicesA.shape[1];\n }\n const windowSize = reduce_util.computeOptimalWindowSize(inSize);\n const reduceInfo = {windowSize, inSize, batchSize};\n const program =\n new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null);\n const inputs = [x];\n if (bestIndicesA != null) {\n inputs.push(bestIndicesA);\n }\n const output = this.compileAndRun(program, inputs, 'int32');\n // No need to run another GPGPU program.\n if (output.shape[1] === 1) {\n return output;\n }\n return this.argReduce(x, reduceType, output);\n }\n\n private argReducePacked(\n x: Tensor, reduceType: 'max'|'min', bestIndicesA: Tensor = null): Tensor {\n const inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape;\n const inSize = inShape[inShape.length - 1];\n const windowSize = reduce_util.computeOptimalWindowSize(inSize);\n const program = new ArgMinMaxPackedProgram(\n inShape, windowSize, reduceType, bestIndicesA == null);\n const inputs = bestIndicesA == null ? [x] : [x, bestIndicesA];\n const output = this.compileAndRun(program, inputs, 'int32');\n if (output.rank === x.rank) {\n return this.argReducePacked(x, reduceType, output);\n }\n return output;\n }\n\n sum(x: Tensor, axes: number[]): Tensor {\n axis_util.assertAxesAreInnerMostDims('sum', axes, x.rank);\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const inSize = util.sizeFromShape(reduceShape);\n const a2D = x.as2D(-1, inSize);\n const outputDType = sumOutType(x.dtype);\n return this.reduce(a2D, 'sum', outputDType).reshape(outShape);\n }\n\n prod(x: Tensor, axes: number[]): Tensor {\n if (this.shouldExecuteOnCPU([x])) {\n return this.cpuBackend.prod(x, axes);\n }\n\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const inSize = util.sizeFromShape(reduceShape);\n const a2D = x.as2D(-1, inSize);\n const outputDType = sumOutType(x.dtype);\n return this.reduce(a2D, 'prod', outputDType).reshape(outShape);\n }\n\n unsortedSegmentSum(\n x: T, segmentIds: Tensor1D, numSegments: number): Tensor {\n let axis = 0;\n const permutation = axis_util.getAxesPermutation([axis], x.rank);\n let permutedX = x;\n if (permutation != null) {\n permutedX = transpose(x, permutation);\n axis = axis_util.getInnerMostAxes(1, x.rank)[0];\n }\n\n const outShape =\n segment_util.computeOutShape(permutedX.shape, axis, numSegments);\n const inSize = util.sizeFromShape([permutedX.shape[axis]]);\n const a2D = permutedX.as2D(-1, inSize);\n const outputDType = sumOutType(x.dtype);\n let result =\n this.segOpCompute(\n a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments)\n .reshape(outShape);\n if (permutation != null) {\n result = transpose(result, axis_util.getUndoAxesPermutation(permutation));\n }\n return result;\n }\n\n private segOpCompute(\n x: Tensor2D, segOpType: 'unsortedSegmentSum', segmentIds: Tensor1D,\n dtype: DataType, numSegments: number): Tensor2D {\n const batchSize = x.shape[0];\n const inSize = x.shape[1];\n const windowSize =\n segment_util.segOpComputeOptimalWindowSize(inSize, numSegments);\n const segOpInfo = {windowSize, inSize, batchSize, numSegments};\n const program = new SegmentOpProgram(segOpInfo, segOpType);\n const output =\n this.compileAndRun(program, [x, segmentIds], dtype);\n // No need to run another GPGPU program.\n if (output.shape[1] === numSegments) {\n return output;\n }\n segmentIds = range(0, numSegments).tile([inSize / windowSize]);\n return this.segOpCompute(output, segOpType, segmentIds, dtype, numSegments);\n }\n\n private argMinMaxReduce(x: Tensor, axis: number, reduceType: 'min'|'max'):\n Tensor {\n const axes = [axis];\n axis_util.assertAxesAreInnerMostDims(\n 'arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes,\n x.rank);\n if (!env().getBool('WEBGL_PACK_REDUCE') || x.rank <= 2) {\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const inSize = util.sizeFromShape(reduceShape);\n const a2D = x.as2D(-1, inSize);\n return this.argReduce(a2D, reduceType).reshape(outShape);\n }\n return this.argReducePacked(x, reduceType);\n }\n\n argMin(x: Tensor, axis: number): Tensor {\n return this.argMinMaxReduce(x, axis, 'min');\n }\n\n argMax(x: Tensor, axis: number): Tensor {\n return this.argMinMaxReduce(x, axis, 'max');\n }\n\n cumsum(x: Tensor, axis: number, exclusive: boolean, reverse: boolean):\n Tensor {\n if (axis !== x.rank - 1) {\n throw new Error(\n `WebGL cumsum shader expects an inner-most axis=${x.rank - 1} ` +\n `but got axis=${axis}`);\n }\n const program = new CumSumProgram(x.shape, exclusive, reverse);\n return this.compileAndRun(program, [x]);\n }\n\n equal(a: Tensor, b: Tensor): Tensor {\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n return this.packedBinaryOp(a, b, binaryop_packed_gpu.EQUAL, 'bool');\n }\n const program = new BinaryOpProgram(binaryop_gpu.EQUAL, a.shape, b.shape);\n return this.compileAndRun(program, [a, b], 'bool');\n }\n\n notEqual(a: Tensor, b: Tensor): Tensor {\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n return this.packedBinaryOp(a, b, binaryop_packed_gpu.NOT_EQUAL, 'bool');\n }\n const program =\n new BinaryOpProgram(binaryop_gpu.NOT_EQUAL, a.shape, b.shape);\n return this.compileAndRun(program, [a, b], 'bool');\n }\n\n less(a: Tensor, b: Tensor): Tensor {\n if (this.shouldExecuteOnCPU([a, b])) {\n return this.cpuBackend.less(a, b);\n }\n\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n return this.packedBinaryOp(a, b, binaryop_packed_gpu.LESS, 'bool');\n }\n\n const program = new BinaryOpProgram(binaryop_gpu.LESS, a.shape, b.shape);\n return this.compileAndRun(program, [a, b], 'bool');\n }\n\n lessEqual(a: Tensor, b: Tensor): Tensor {\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n return this.packedBinaryOp(a, b, binaryop_packed_gpu.LESS_EQUAL, 'bool');\n }\n const program =\n new BinaryOpProgram(binaryop_gpu.LESS_EQUAL, a.shape, b.shape);\n return this.compileAndRun(program, [a, b], 'bool');\n }\n\n greater(a: Tensor, b: Tensor): Tensor {\n if (this.shouldExecuteOnCPU([a, b])) {\n return this.cpuBackend.greater(a, b);\n }\n\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n return this.packedBinaryOp(a, b, binaryop_packed_gpu.GREATER, 'bool');\n }\n\n const program = new BinaryOpProgram(binaryop_gpu.GREATER, a.shape, b.shape);\n return this.compileAndRun(program, [a, b], 'bool');\n }\n\n greaterEqual(a: Tensor, b: Tensor): Tensor {\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n return this.packedBinaryOp(\n a, b, binaryop_packed_gpu.GREATER_EQUAL, 'bool');\n }\n const program =\n new BinaryOpProgram(binaryop_gpu.GREATER_EQUAL, a.shape, b.shape);\n return this.compileAndRun(program, [a, b], 'bool');\n }\n\n logicalNot(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.LOGICAL_NOT);\n return this.compileAndRun(program, [x]);\n }\n\n logicalAnd(a: Tensor, b: Tensor): Tensor {\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n return this.packedBinaryOp(a, b, binaryop_packed_gpu.LOGICAL_AND, 'bool');\n }\n const program =\n new BinaryOpProgram(binaryop_gpu.LOGICAL_AND, a.shape, b.shape);\n return this.compileAndRun(program, [a, b], 'bool');\n }\n\n logicalOr(a: Tensor, b: Tensor): Tensor {\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n return this.packedBinaryOp(a, b, binaryop_packed_gpu.LOGICAL_OR, 'bool');\n }\n const program =\n new BinaryOpProgram(binaryop_gpu.LOGICAL_OR, a.shape, b.shape);\n return this.compileAndRun(program, [a, b], 'bool');\n }\n\n select(condition: Tensor, a: Tensor, b: Tensor): Tensor {\n const program = new SelectProgram(condition.rank, a.shape, a.rank);\n return this.compileAndRun(\n program, [condition, a, b], upcastType(a.dtype, b.dtype));\n }\n\n where(condition: Tensor): Tensor2D {\n warn(\n 'tf.where() in webgl locks the UI thread. ' +\n 'Call tf.whereAsync() instead');\n const condVals = condition.dataSync();\n return whereImpl(condition.shape, condVals);\n }\n\n topk(x: T, k: number, sorted: boolean): [T, T] {\n const xVals = x.dataSync();\n return topkImpl(xVals, x.shape, x.dtype as NumericDataType, k, sorted);\n }\n\n min(x: Tensor, axes: number[]): Tensor {\n axis_util.assertAxesAreInnerMostDims('min', axes, x.rank);\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const inSize = util.sizeFromShape(reduceShape);\n const a2D = x.as2D(-1, inSize);\n return this.reduce(a2D, 'min', a2D.dtype).reshape(outShape);\n }\n\n minimum(a: Tensor, b: Tensor): Tensor {\n if (this.shouldExecuteOnCPU([a, b])) {\n return this.cpuBackend.minimum(a, b);\n }\n\n const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?\n new BinaryOpPackedProgram(binaryop_packed_gpu.MIN, a.shape, b.shape) :\n new BinaryOpProgram(binaryop_gpu.MIN, a.shape, b.shape);\n return this.compileAndRun(program, [a, b]);\n }\n\n mod(a: Tensor, b: Tensor): Tensor {\n const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?\n new BinaryOpPackedProgram(binaryop_packed_gpu.MOD, a.shape, b.shape) :\n new BinaryOpProgram(binaryop_gpu.MOD, a.shape, b.shape);\n return this.compileAndRun(program, [a, b]);\n }\n\n max(x: Tensor, axes: number[]): Tensor {\n if (this.shouldExecuteOnCPU([x])) {\n return this.cpuBackend.max(x, axes);\n }\n\n axis_util.assertAxesAreInnerMostDims('max', axes, x.rank);\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const inSize = util.sizeFromShape(reduceShape);\n const a2D = x.as2D(-1, inSize);\n return this.reduce(a2D, 'max', a2D.dtype).reshape(outShape);\n }\n\n maximum(a: Tensor, b: Tensor): Tensor {\n if (this.shouldExecuteOnCPU([a, b])) {\n return this.cpuBackend.maximum(a, b);\n }\n\n const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?\n new BinaryOpPackedProgram(binaryop_packed_gpu.MAX, a.shape, b.shape) :\n new BinaryOpProgram(binaryop_gpu.MAX, a.shape, b.shape);\n return this.compileAndRun(program, [a, b]);\n }\n\n all(x: Tensor, axes: number[]): Tensor {\n axis_util.assertAxesAreInnerMostDims('all', axes, x.rank);\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const inSize = util.sizeFromShape(reduceShape);\n const a2D = x.as2D(-1, inSize);\n return this.reduce(a2D, 'all', a2D.dtype).reshape(outShape);\n }\n\n any(x: Tensor, axes: number[]): Tensor {\n axis_util.assertAxesAreInnerMostDims('any', axes, x.rank);\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const inSize = util.sizeFromShape(reduceShape);\n const a2D = x.as2D(-1, inSize);\n return this.reduce(a2D, 'any', a2D.dtype).reshape(outShape);\n }\n\n floorDiv(a: Tensor, b: Tensor): Tensor {\n const op = binaryop_gpu.INT_DIV;\n const outputDtype = 'int32';\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n return this.packedBinaryOp(\n a, b, binaryop_packed_gpu.INT_DIV, outputDtype);\n }\n const program = new BinaryOpProgram(op, a.shape, b.shape);\n return this.compileAndRun(program, [a, b], outputDtype);\n }\n\n add(a: Tensor, b: Tensor): Tensor {\n if (a.dtype === 'complex64' && b.dtype === 'complex64') {\n return this.complexSeparableBinaryOp(a, b, binaryop_gpu.ADD);\n }\n\n if (this.shouldExecuteOnCPU([a, b])) {\n return this.cpuBackend.add(a, b);\n }\n\n const dtype = upcastType(a.dtype, b.dtype);\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n return this.packedBinaryOp(a, b, binaryop_gpu.ADD, dtype);\n }\n const program = new BinaryOpProgram(binaryop_gpu.ADD, a.shape, b.shape);\n return this.compileAndRun(program, [a, b], dtype);\n }\n\n private packedUnaryOp(x: TensorInfo, op: string, dtype: DataType) {\n const program = new UnaryOpPackedProgram(x.shape, op);\n return this.compileAndRun(program, [x], dtype);\n }\n\n private packedBinaryOp(\n a: TensorInfo, b: TensorInfo, op: string, dtype: DataType,\n checkOutOfBounds = false) {\n const program =\n new BinaryOpPackedProgram(op, a.shape, b.shape, checkOutOfBounds);\n return this.compileAndRun(program, [a, b], dtype);\n }\n\n /**\n * Computes a complex binary operation that can be decomposed into a simple\n * binary operation on both the real and imagary parts.\n */\n private complexSeparableBinaryOp(a: Tensor, b: Tensor, op: string): Tensor {\n const aData = this.texData.get(a.dataId);\n const bData = this.texData.get(b.dataId);\n\n const [real, imag] = [\n [aData.complexTensors.real, bData.complexTensors.real],\n [aData.complexTensors.imag, bData.complexTensors.imag]\n ].map(complexParts => {\n const [aPart, bPart] = complexParts;\n\n const aHandle = this.makeComplexComponentTensorInfo(a, aPart);\n const bHandle = this.makeComplexComponentTensorInfo(b, bPart);\n\n const program = new BinaryOpProgram(op, a.shape, b.shape);\n return this.compileAndRun(\n program, [aHandle, bHandle], upcastType(aPart.dtype, bPart.dtype));\n });\n\n const complex = this.complex(real, imag);\n real.dispose();\n imag.dispose();\n return complex;\n }\n\n // Returns a TensorInfo with the complex shape and the dataId of the\n // underlying part. We need to do this because a reshaped complex tensor is\n // not reflected in its parts.\n private makeComplexComponentTensorInfo(\n complexTensor: Tensor, complexPart: Tensor): TensorInfo {\n return {\n dataId: complexPart.dataId,\n dtype: complexPart.dtype,\n shape: complexTensor.shape\n };\n }\n\n addN(tensors: T[]): T {\n if (tensors.length === 1) {\n return tensors[0];\n }\n\n // Limit the number of uploaded textures for optimization.\n if (tensors.length > env().get('WEBGL_MAX_TEXTURES_IN_SHADER')) {\n const midIndex = Math.floor(tensors.length / 2);\n const leftSide = this.addN(tensors.slice(0, midIndex));\n const rightSide = this.addN(tensors.slice(midIndex));\n return this.addN([leftSide, rightSide]);\n }\n\n const dtype =\n tensors.map(t => t.dtype).reduce((d1, d2) => upcastType(d1, d2));\n const shapes = tensors.map(t => t.shape);\n // We can make sure shapes are identical in op level.\n const usePackedOp = env().getBool('WEBGL_PACK');\n const program = usePackedOp ?\n new AddNPackedProgram(tensors[0].shape, shapes) :\n new AddNProgram(tensors[0].shape, shapes);\n return this.compileAndRun(program, tensors, dtype);\n }\n\n subtract(a: Tensor, b: Tensor): Tensor {\n if (a.dtype === 'complex64' && b.dtype === 'complex64') {\n return this.complexSeparableBinaryOp(a, b, binaryop_gpu.SUB);\n }\n\n if (this.shouldExecuteOnCPU([a, b])) {\n return this.cpuBackend.subtract(a, b);\n }\n const dtype = upcastType(a.dtype, b.dtype);\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n return this.packedBinaryOp(a, b, binaryop_gpu.SUB, a.dtype);\n }\n const program = new BinaryOpProgram(binaryop_gpu.SUB, a.shape, b.shape);\n return this.compileAndRun(program, [a, b], dtype);\n }\n\n pow(a: T, b: Tensor): T {\n const usePackedOp = env().getBool('WEBGL_PACK_BINARY_OPERATIONS');\n const program = usePackedOp ?\n new BinaryOpPackedProgram(binaryop_packed_gpu.POW, a.shape, b.shape) :\n new BinaryOpProgram(binaryop_gpu.POW, a.shape, b.shape);\n const dtype = upcastType(a.dtype, b.dtype);\n return this.compileAndRun(program, [a, b], dtype);\n }\n\n ceil(x: T): T {\n if (this.shouldExecuteOnCPU([x])) {\n return this.cpuBackend.ceil(x);\n }\n\n if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {\n return this.packedUnaryOp(x, unary_op.CEIL, x.dtype) as T;\n }\n\n const program = new UnaryOpProgram(x.shape, unary_op.CEIL);\n return this.compileAndRun(program, [x]);\n }\n\n floor(x: T): T {\n if (this.shouldExecuteOnCPU([x])) {\n return this.cpuBackend.floor(x);\n }\n\n if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {\n return this.packedUnaryOp(x, unary_op.FLOOR, x.dtype) as T;\n }\n\n const program = new UnaryOpProgram(x.shape, unary_op.FLOOR);\n return this.compileAndRun(program, [x]);\n }\n\n sign(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.SIGN);\n return this.compileAndRun(program, [x]);\n }\n\n isNaN(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.IS_NAN);\n return this.compileAndRun(program, [x], 'bool');\n }\n isInf(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.IS_INF);\n return this.compileAndRun(program, [x], 'bool');\n }\n isFinite(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.IS_FINITE);\n return this.compileAndRun(program, [x], 'bool');\n }\n\n round(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.ROUND);\n return this.compileAndRun(program, [x]);\n }\n\n exp(x: T): T {\n if (this.shouldExecuteOnCPU([x])) {\n return this.cpuBackend.exp(x);\n }\n\n if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {\n return this.packedUnaryOp(x, unary_op.EXP, x.dtype) as T;\n }\n\n const program = new UnaryOpProgram(x.shape, unary_op.EXP);\n return this.compileAndRun(program, [x]);\n }\n\n expm1(x: T): T {\n if (this.shouldExecuteOnCPU([x])) {\n return this.cpuBackend.expm1(x);\n }\n\n if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {\n return this.packedUnaryOp(x, unary_op.EXPM1, x.dtype) as T;\n }\n\n const program = new UnaryOpProgram(x.shape, unary_op.EXPM1);\n return this.compileAndRun(program, [x]);\n }\n\n softmax(logits: T, dim: number): T {\n const axes = util.parseAxisParam([dim], logits.shape);\n const maxLogit = this.max(logits, axes);\n const expandedShape = axis_util.expandShapeToKeepDim(maxLogit.shape, axes);\n const a = this.subtract(logits, maxLogit.reshape(expandedShape));\n const b = this.exp(a);\n const sumExp = this.sum(b, axes).reshape(expandedShape);\n\n // TODO(annxingyuan): Call divImpl rather than op as part of softmax kernel\n // modularization.\n return div(b, sumExp);\n }\n\n log(x: T): T {\n if (this.shouldExecuteOnCPU([x])) {\n return this.cpuBackend.log(x);\n }\n\n if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {\n return this.packedUnaryOp(x, unary_packed_op.LOG, x.dtype) as T;\n }\n\n const program = new UnaryOpProgram(x.shape, unary_op.LOG);\n return this.compileAndRun(program, [x]);\n }\n\n log1p(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.LOG1P);\n return this.compileAndRun(program, [x]);\n }\n\n sqrt(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.SQRT);\n return this.compileAndRun(program, [x]);\n }\n\n rsqrt(x: T): T {\n if (this.shouldExecuteOnCPU([x])) {\n return this.cpuBackend.rsqrt(x);\n }\n const program = new UnaryOpProgram(x.shape, unary_op.RSQRT);\n return this.compileAndRun(program, [x]);\n }\n\n reciprocal(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.RECIPROCAL);\n return this.compileAndRun(program, [x]);\n }\n\n relu(x: T): T {\n let program: UnaryOpProgram|UnaryOpPackedProgram;\n if (env().getBool('WEBGL_PACK')) {\n program = new UnaryOpPackedProgram(x.shape, unary_packed_op.RELU);\n } else {\n program = new UnaryOpProgram(x.shape, unary_op.RELU);\n }\n return this.compileAndRun(program, [x]);\n }\n\n relu6(x: T): T {\n let program: UnaryOpProgram|UnaryOpPackedProgram;\n if (env().getBool('WEBGL_PACK')) {\n program = new UnaryOpPackedProgram(x.shape, unary_packed_op.RELU6);\n } else {\n program = new UnaryOpProgram(x.shape, unary_op.RELU6);\n }\n return this.compileAndRun(program, [x]);\n }\n\n prelu(x: T, alpha: T): T {\n const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?\n new BinaryOpPackedProgram(\n binaryop_packed_gpu.PRELU, x.shape, alpha.shape) :\n new BinaryOpProgram(binaryop_gpu.PRELU, x.shape, alpha.shape);\n return this.compileAndRun(program, [x, alpha]);\n }\n\n elu(x: T): T {\n if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {\n return this.packedUnaryOp(x, unary_packed_op.ELU, x.dtype) as T;\n }\n const program = new UnaryOpProgram(x.shape, unary_op.ELU);\n return this.compileAndRun(program, [x]);\n }\n\n eluDer(dy: T, y: T): T {\n const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?\n new BinaryOpPackedProgram(\n binaryop_packed_gpu.ELU_DER, dy.shape, y.shape) :\n new BinaryOpProgram(binaryop_gpu.ELU_DER, dy.shape, y.shape);\n return this.compileAndRun(program, [dy, y]);\n }\n\n selu(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.SELU);\n return this.compileAndRun(program, [x]);\n }\n\n int(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.TO_INT);\n return this.compileAndRun(program, [x], 'int32');\n }\n\n clip(x: T, min: number, max: number): T {\n let program;\n if (env().getBool('WEBGL_PACK_CLIP')) {\n program = new ClipPackedProgram(x.shape);\n } else {\n program = new ClipProgram(x.shape);\n }\n const customSetup = program.getCustomSetupFunc(min, max);\n return this.compileAndRun(program, [x], null, customSetup);\n }\n\n abs(x: T): T {\n if (this.shouldExecuteOnCPU([x])) {\n return this.cpuBackend.abs(x);\n }\n\n if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {\n return this.packedUnaryOp(x, unary_op.ABS, x.dtype) as T;\n }\n\n const program = new UnaryOpProgram(x.shape, unary_op.ABS);\n return this.compileAndRun(program, [x]);\n }\n\n complexAbs(x: T): T {\n const xData = this.texData.get(x.dataId);\n\n const program = new ComplexAbsProgram(x.shape);\n const inputs = [\n this.makeComplexComponentTensorInfo(x, xData.complexTensors.real),\n this.makeComplexComponentTensorInfo(x, xData.complexTensors.imag),\n ];\n\n return this.compileAndRun(program, inputs) as T;\n }\n\n sigmoid(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.SIGMOID);\n return this.compileAndRun(program, [x]);\n }\n\n softplus(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.SOFTPLUS);\n return this.compileAndRun(program, [x]);\n }\n\n sin(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.SIN);\n return this.compileAndRun(program, [x]);\n }\n\n cos(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.COS);\n return this.compileAndRun(program, [x]);\n }\n\n tan(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.TAN);\n return this.compileAndRun(program, [x]);\n }\n\n asin(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.ASIN);\n return this.compileAndRun(program, [x]);\n }\n\n acos(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.ACOS);\n return this.compileAndRun(program, [x]);\n }\n\n atan(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.ATAN);\n return this.compileAndRun(program, [x]);\n }\n\n atan2(a: T, b: T): T {\n const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?\n new BinaryOpPackedProgram(binaryop_packed_gpu.ATAN2, a.shape, b.shape) :\n new BinaryOpProgram(binaryop_gpu.ATAN2, a.shape, b.shape);\n return this.compileAndRun(program, [a, b]);\n }\n\n sinh(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.SINH);\n return this.compileAndRun(program, [x]);\n }\n\n cosh(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.COSH);\n return this.compileAndRun(program, [x]);\n }\n\n tanh(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.TANH);\n return this.compileAndRun(program, [x]);\n }\n\n asinh(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.ASINH);\n return this.compileAndRun(program, [x]);\n }\n\n acosh(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.ACOSH);\n return this.compileAndRun(program, [x]);\n }\n\n atanh(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.ATANH);\n return this.compileAndRun(program, [x]);\n }\n\n erf(x: T): T {\n const program = new UnaryOpProgram(x.shape, unary_op.ERF);\n return this.compileAndRun(program, [x]);\n }\n\n step(x: T, alpha: number): T {\n const program = new UnaryOpProgram(x.shape, unary_op.STEP(alpha));\n return this.compileAndRun(program, [x]);\n }\n\n private conv2dByMatMul(\n x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor,\n activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {\n // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the\n // result from 2D to 4D.\n const xShape = x.shape;\n const xTexData = this.texData.get(x.dataId);\n const sharedMatMulDim = convInfo.inChannels;\n const outerShapeX = xShape[0] * xShape[1] * xShape[2];\n const outerShapeFilter = convInfo.outChannels;\n const isChannelsLast = convInfo.dataFormat === 'channelsLast';\n const transposeA = false;\n const transposeB = false;\n\n // TODO: Once reduction ops are packed, batchMatMul will always be packed\n // and we can remove this condition.\n const batchMatMulWillBeUnpacked =\n (outerShapeX === 1 || outerShapeFilter === 1) &&\n sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;\n const reshapeWillBeExpensive = xShape[2] % 2 !== 0 && !!xTexData.isPacked;\n\n if (batchMatMulWillBeUnpacked || !env().getBool('WEBGL_LAZILY_UNPACK') ||\n !env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ||\n !reshapeWillBeExpensive) {\n const targetShape = isChannelsLast ? xShape[0] * xShape[1] * xShape[2] :\n xShape[0] * xShape[2] * xShape[3];\n const xReshaped = this.reshape(x, [1, targetShape, convInfo.inChannels]);\n const filterReshaped =\n this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]);\n\n return this.reshape(\n this.fusedBatchMatMul({\n a: xReshaped as Tensor3D,\n b: filterReshaped as Tensor3D,\n transposeA,\n transposeB,\n bias,\n activation,\n preluActivationWeights\n }),\n convInfo.outShape);\n }\n\n // Following optimization is specific to packed |x| with odd row count\n // (For example, in channelLast mode, 'row count' refers to x.shape[2]):\n // we avoid expensive packed 2x2 reshape by padding row count to next,\n // even number. When x.shape[2] is odd, the result of packed batchMatMul is\n // the same (has the same texture layout and and values in the texture) as\n // it is for even x.shape[2] + 1. We make the odd-rows tensor to look like\n // even-rows tensor before the operation and, after the batchMatMul,\n // fix the even-rows result to have odd number of rows.\n const targetShape = isChannelsLast ?\n xShape[0] * xShape[1] * (xShape[2] + 1) :\n xShape[0] * xShape[2] * (xShape[3] + 1);\n const xReshaped: TensorInfo = {\n dataId: x.dataId,\n shape: [1, targetShape, convInfo.inChannels],\n dtype: x.dtype\n };\n // xTexData.shape gets referenced from GPGPUBinary.inShapeInfos.\n // Decrementing row count, after batchMatMul->...->compileProgram leads to\n // invalid row count within the reference in GPGPUBinary.inShapeInfos.\n // Alternative fix would be to provide a copy to GPGPUBinary.inShapeInfos\n // in compileProgram method, but that would affect compilation of all\n // programs - instead, provide a copy here, with even row count, before\n // calling batchMatMul->...->compileProgram and after that, the original\n // xTexData.shape is restored.\n const originalXTexDataShape = xTexData.shape;\n xTexData.shape = xTexData.shape.slice();\n xTexData.shape[xTexData.shape.length - 2]++;\n util.assert(\n webgl_util.isReshapeFree(xTexData.shape, xReshaped.shape),\n () => `packed reshape ${xTexData.shape} to ${\n xReshaped.shape} isn't free`);\n const filterReshaped =\n this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]);\n\n const pointwiseConv = this.fusedBatchMatMul({\n a: xReshaped as Tensor3D,\n b: filterReshaped as Tensor3D,\n transposeA,\n transposeB,\n bias,\n activation,\n preluActivationWeights\n });\n const pointwiseConvTexData = this.texData.get(pointwiseConv.dataId);\n util.assert(\n pointwiseConvTexData.isPacked,\n () => 'batchMatMul result is expected to be packed');\n // Restore the input shape to original.\n xTexData.shape = originalXTexDataShape;\n // Set the output shape - there is no need for expensive reshape as data\n // layout is already correct.\n pointwiseConvTexData.shape = convInfo.outShape;\n return ENGINE.makeTensorFromDataId(\n pointwiseConv.dataId, convInfo.outShape, pointwiseConv.dtype) as\n Tensor4D;\n }\n\n private conv2dWithIm2Row(\n x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor,\n activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {\n // Rearranges conv2d input so each block to be convolved over forms the\n // column of a new matrix with shape [filterWidth * filterHeight *\n // inChannels, outHeight * outWidth]. The filter is also rearranged so each\n // output channel forms a row of a new matrix with shape [outChannels,\n // filterWidth * filterHeight * inChannels]. The convolution is then\n // computed by multiplying these matrices and reshaping the result.\n const {\n filterWidth,\n filterHeight,\n inChannels,\n outWidth,\n outHeight,\n dataFormat\n } = convInfo;\n\n const isChannelsLast = dataFormat === 'channelsLast';\n\n const sharedDim = filterWidth * filterHeight * inChannels;\n const numCols = outHeight * outWidth;\n const x2ColShape = [sharedDim, numCols];\n const transposeA = true;\n const transposeB = false;\n\n const xSqueezed = x.squeeze([0]);\n const w2Row = filter.reshape([1, sharedDim, -1]);\n\n const im2ColProgram =\n new Im2ColPackedProgram(x2ColShape, xSqueezed.shape, convInfo);\n const im2Col: Tensor3D =\n this.compileAndRun(im2ColProgram, [xSqueezed]).reshape([\n 1, x2ColShape[0], x2ColShape[1]\n ]);\n\n const hasBias = bias != null;\n const hasPreluActivationWeights = preluActivationWeights != null;\n const fusedActivation =\n activation ? mapActivationToShaderProgram(activation, true) : null;\n const matmulProgram = new MatMulPackedProgram(\n im2Col.shape, [1, numCols, convInfo.outChannels], transposeA,\n transposeB, hasBias, fusedActivation, hasPreluActivationWeights);\n const inputs: TensorInfo[] = [im2Col, w2Row];\n if (bias) {\n inputs.push(bias);\n }\n if (hasPreluActivationWeights) {\n inputs.push(preluActivationWeights);\n }\n const product = this.compileAndRun(matmulProgram, inputs);\n\n if (isChannelsLast) {\n return product.reshape([1, outHeight, outWidth, convInfo.outChannels]);\n } else {\n return product.reshape([1, convInfo.outChannels, outHeight, outWidth]);\n }\n }\n\n fusedConv2d(\n {input, filter, convInfo, bias, activation, preluActivationWeights}:\n FusedConv2DConfig): Tensor4D {\n if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&\n convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&\n convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&\n (convInfo.padInfo.type === 'SAME' ||\n convInfo.padInfo.type === 'VALID')) {\n return this.conv2dByMatMul(\n input, filter, convInfo, bias, activation, preluActivationWeights);\n }\n if (env().getBool('WEBGL_CONV_IM2COL') && input.shape[0] === 1) {\n return this.conv2dWithIm2Row(\n input, filter, convInfo, bias, activation, preluActivationWeights);\n }\n\n const hasBias = bias != null;\n const hasPreluActivationWeights = preluActivationWeights != null;\n const fusedActivation =\n activation ? mapActivationToShaderProgram(activation, false) : null;\n const program = new Conv2DProgram(\n convInfo, hasBias, fusedActivation, hasPreluActivationWeights);\n const inputs: TensorInfo[] = [input, filter];\n if (bias) {\n inputs.push(bias);\n }\n if (preluActivationWeights) {\n inputs.push(preluActivationWeights);\n }\n return this.compileAndRun(program, inputs);\n }\n\n conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&\n convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&\n convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&\n (convInfo.padInfo.type === 'SAME' ||\n convInfo.padInfo.type === 'VALID')) {\n return this.conv2dByMatMul(x, filter, convInfo);\n }\n if (env().getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {\n return this.conv2dWithIm2Row(x, filter, convInfo);\n }\n const program = new Conv2DProgram(convInfo);\n return this.compileAndRun(program, [x, filter]);\n }\n\n conv2dDerInput(dy: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n const program = new Conv2DDerInputProgram(convInfo);\n return this.compileAndRun(program, [dy, filter]);\n }\n\n conv2dDerFilter(x: Tensor4D, dy: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n const program = new Conv2DDerFilterProgram(convInfo);\n return this.compileAndRun(program, [x, dy]);\n }\n\n fusedDepthwiseConv2D(\n {input, filter, convInfo, bias, activation, preluActivationWeights}:\n FusedConv2DConfig): Tensor4D {\n const shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') &&\n convInfo.strideWidth <= 2 &&\n convInfo.outChannels / convInfo.inChannels === 1;\n const fusedActivation = activation ?\n mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) :\n null;\n const inputs: Tensor[] = [input, filter];\n\n const hasBias = bias != null;\n const hasPreluActivationWeights = preluActivationWeights != null;\n if (hasBias) {\n inputs.push(bias);\n }\n if (hasPreluActivationWeights) {\n inputs.push(preluActivationWeights);\n }\n\n let program: DepthwiseConv2DProgram|DepthwiseConvPacked2DProgram;\n if (shouldPackDepthwiseConv) {\n program = new DepthwiseConvPacked2DProgram(\n convInfo, hasBias, fusedActivation, hasPreluActivationWeights);\n return this.compileAndRun(program, inputs);\n }\n\n program = new DepthwiseConv2DProgram(\n convInfo, hasBias, fusedActivation, hasPreluActivationWeights);\n return this.compileAndRun(program, inputs);\n }\n\n depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n let program: DepthwiseConv2DProgram|DepthwiseConvPacked2DProgram;\n if (env().getBool('WEBGL_PACK_DEPTHWISECONV') &&\n convInfo.strideWidth <= 2 &&\n convInfo.outChannels / convInfo.inChannels === 1) {\n program = new DepthwiseConvPacked2DProgram(convInfo);\n return this.compileAndRun(program, [x, filter]);\n }\n\n program = new DepthwiseConv2DProgram(convInfo);\n return this.compileAndRun(program, [x, filter]);\n }\n\n depthwiseConv2DDerInput(dy: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n const program = new DepthwiseConv2DDerInputProgram(convInfo);\n return this.compileAndRun(program, [dy, filter]);\n }\n\n depthwiseConv2DDerFilter(x: Tensor4D, dy: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n const program = new DepthwiseConv2DDerFilterProgram(convInfo);\n return this.compileAndRun(program, [x, dy]);\n }\n\n conv3d(x: Tensor5D, filter: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n const program = new Conv3DProgram(convInfo);\n return this.compileAndRun(program, [x, filter]);\n }\n\n conv3dDerInput(dy: Tensor5D, filter: Tensor5D, convInfo: Conv3DInfo):\n Tensor5D {\n const program = new Conv3DDerInputProgram(convInfo);\n return this.compileAndRun(program, [dy, filter]);\n }\n\n conv3dDerFilter(x: Tensor5D, dy: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n const program = new Conv3DDerFilterProgram(convInfo);\n return this.compileAndRun(program, [x, dy]);\n }\n\n maxPool(x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n const program = new Pool2DProgram(convInfo, 'max', false);\n return this.compileAndRun(program, [x]);\n }\n\n avgPool(x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n const program = new Pool2DProgram(convInfo, 'avg', false);\n return this.compileAndRun(program, [x], 'float32');\n }\n\n maxPoolBackprop(dy: Tensor4D, x: Tensor4D, y: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n const getPositions = true;\n const maxPoolPositionsProgram =\n new Pool2DProgram(convInfo, 'max', getPositions);\n const maxPoolPositions: Tensor4D =\n this.compileAndRun(maxPoolPositionsProgram, [x]);\n\n const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo);\n const result = this.compileAndRun(\n maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype);\n maxPoolPositions.dispose();\n return result as Tensor4D;\n }\n\n avgPoolBackprop(dy: Tensor4D, x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n const avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo);\n return this.compileAndRun(avgPoolBackpropProgram, [dy], x.dtype);\n }\n\n cast(x: T, dtype: DataType): T {\n return backend_util.castTensor(x, dtype, this);\n }\n\n unstack(x: Tensor, axis: number): Tensor[] {\n const num = x.shape[axis];\n const outShape: number[] = new Array(x.rank - 1);\n let outIndex = 0;\n for (let i = 0; i < x.rank; i++) {\n if (i !== axis) {\n outShape[outIndex++] = x.shape[i];\n }\n }\n\n const begin = new Array(x.rank).fill(0);\n const size = x.shape.slice();\n size[axis] = 1;\n const res = new Array(num);\n for (let i = 0; i < res.length; i++) {\n begin[axis] = i;\n res[i] = this.slice(x, begin, size).reshape(outShape);\n }\n return res;\n }\n\n avgPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n const program = new Pool3DProgram(convInfo, 'avg', false);\n return this.compileAndRun(program, [x], 'float32');\n }\n\n avgPool3dBackprop(dy: Tensor5D, x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n const avgPool3dBackpropProgram = new AvgPool3DBackpropProgram(convInfo);\n return this.compileAndRun(avgPool3dBackpropProgram, [dy], x.dtype);\n }\n\n maxPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n const program = new Pool3DProgram(convInfo, 'max', false);\n return this.compileAndRun(program, [x], 'float32');\n }\n\n maxPool3dBackprop(\n dy: Tensor5D, x: Tensor5D, y: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n const getPositions = true;\n const maxPool3dPositionsProgram =\n new Pool3DProgram(convInfo, 'max', getPositions);\n const maxPool3dPositions: Tensor5D =\n this.compileAndRun(maxPool3dPositionsProgram, [x]);\n const maxPool3dBackPropProgram = new MaxPool3DBackpropProgram(convInfo);\n const result = this.compileAndRun(\n maxPool3dBackPropProgram, [dy, maxPool3dPositions], x.dtype);\n maxPool3dPositions.dispose();\n return result as Tensor5D;\n }\n\n reshape(x: Tensor, shape: ShapeMap[R]): Tensor {\n const texData = this.texData.get(x.dataId);\n if (texData.isPacked && !webgl_util.isReshapeFree(x.shape, shape) &&\n !(texData.texture !== null &&\n webgl_util.isReshapeFree(texData.shape, shape))) {\n const info = this.packedReshape(x, shape);\n return ENGINE.makeTensorFromDataId(info.dataId, info.shape, info.dtype) as\n Tensor;\n }\n return backend_util.reshapeTensor(x, shape);\n }\n\n resizeBilinear(\n x: Tensor4D, newHeight: number, newWidth: number,\n alignCorners: boolean): Tensor4D {\n const program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ?\n new ResizeBilinearPackedProgram(\n x.shape, newHeight, newWidth, alignCorners) :\n new ResizeBilinearProgram(x.shape, newHeight, newWidth, alignCorners);\n return this.compileAndRun(program, [x], 'float32');\n }\n\n resizeBilinearBackprop(dy: Tensor4D, x: Tensor4D, alignCorners: boolean):\n Tensor4D {\n const program = new ResizeBilinearBackpropProgram(dy, x, alignCorners);\n\n return this.compileAndRun(program, [dy]);\n }\n\n resizeNearestNeighbor(\n x: Tensor4D, newHeight: number, newWidth: number,\n alignCorners: boolean): Tensor4D {\n const program = new ResizeNearestNeighborProgram(\n x.shape, newHeight, newWidth, alignCorners);\n return this.compileAndRun(program, [x]);\n }\n\n resizeNearestNeighborBackprop(\n dy: Tensor4D, x: Tensor4D, alignCorners: boolean): Tensor4D {\n const program =\n new ResizeNearestNeigborBackpropProgram(dy, x, alignCorners);\n return this.compileAndRun(program, [dy]);\n }\n\n multinomial(\n logits: Tensor2D, normalized: boolean, numSamples: number,\n seed: number): Tensor2D {\n const probs = normalized ? logits : softmax(logits);\n const batchSize = probs.shape[0];\n const numOutcomes = probs.shape[1];\n const program = new MultinomialProgram(batchSize, numOutcomes, numSamples);\n const customSetup = program.getCustomSetupFunc(seed);\n return this.compileAndRun(program, [probs], 'int32', customSetup);\n }\n\n oneHot(indices: Tensor1D, depth: number, onValue: number, offValue: number):\n Tensor2D {\n const program = new OneHotProgram(indices.size, depth, onValue, offValue);\n return this.compileAndRun(program, [indices]);\n }\n\n diag(x: Tensor): Tensor {\n const program = new DiagProgram(x.size);\n return this.compileAndRun(program, [x]);\n }\n\n nonMaxSuppression(\n boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number,\n iouThreshold: number, scoreThreshold: number): Tensor1D {\n warn(\n 'tf.nonMaxSuppression() in webgl locks the UI thread. ' +\n 'Call tf.nonMaxSuppressionAsync() instead');\n const boxesVals = boxes.dataSync();\n const scoresVals = scores.dataSync();\n return nonMaxSuppressionV3(\n boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);\n }\n\n cropAndResize(\n image: Tensor4D, boxes: Tensor2D, boxIndex: Tensor1D,\n cropSize: [number, number], method: 'bilinear'|'nearest',\n extrapolationValue: number): Tensor4D {\n const program = new CropAndResizeProgram(\n image.shape, boxes.shape, cropSize, method, extrapolationValue);\n return this.compileAndRun(program, [image, boxes, boxIndex], 'float32');\n }\n\n depthToSpace(x: Tensor4D, blockSize: number, dataFormat: 'NHWC'|'NCHW'):\n Tensor4D {\n util.assert(\n blockSize > 1,\n () =>\n `blockSize should be > 1 for depthToSpace, but was: ${blockSize}`);\n\n const batchSize = x.shape[0];\n const inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2];\n const inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3];\n const inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1];\n\n const outputHeight = inputHeight * blockSize;\n const outputWidth = inputWidth * blockSize;\n const outputDepth = inputDepth / (blockSize * blockSize);\n\n const outputShape = (dataFormat === 'NHWC') ?\n [batchSize, outputHeight, outputWidth, outputDepth] :\n [batchSize, outputDepth, outputHeight, outputWidth];\n\n const program = new DepthToSpaceProgram(outputShape, blockSize, dataFormat);\n return this.compileAndRun(program, [x]);\n }\n\n split(x: T, sizeSplits: number[], axis: number): T[] {\n return split(x, sizeSplits, axis);\n }\n\n scatterND(\n indices: Tensor, updates: Tensor, shape: ShapeMap[R]): Tensor {\n const {sliceRank, numUpdates, sliceSize, strides, outputSize} =\n scatter_nd_util.calculateShapes(updates, indices, shape);\n\n const flattenShape = [outputSize / sliceSize, sliceSize];\n const flattenIndices = indices.reshape([numUpdates, sliceRank]);\n const flattenX = updates.reshape([numUpdates, sliceSize]);\n\n if (outputSize === 0) {\n return backend_util.reshapeTensor(tensor([]), shape);\n }\n const defaultValue = scalar(0);\n const program = new ScatterProgram(\n numUpdates, sliceRank, flattenIndices.rank, flattenX.rank, strides,\n flattenShape);\n const res: Tensor =\n this.compileAndRun(program, [flattenX, flattenIndices, defaultValue]);\n return res.reshape(shape);\n }\n\n sparseToDense(\n sparseIndices: Tensor, sparseValues: Tensor, outputShape: ShapeMap[R],\n defaultValue: Scalar): Tensor {\n const {sliceRank, numUpdates, strides, outputSize} =\n scatter_nd_util.calculateShapes(\n sparseValues, sparseIndices, outputShape);\n\n const sumDupeIndices = false;\n const program = new ScatterProgram(\n numUpdates, sliceRank, sparseIndices.rank, sparseValues.rank, strides,\n [outputSize, 1], sumDupeIndices);\n const res: Tensor = this.compileAndRun(\n program, [sparseValues, sparseIndices, defaultValue]);\n return res.reshape(outputShape);\n }\n\n fft(x: Tensor2D): Tensor2D {\n const inverse = false;\n return this.fftImpl(x, inverse);\n }\n\n ifft(x: Tensor2D): Tensor2D {\n const inverse = true;\n return this.fftImpl(x, inverse);\n }\n\n private fftImpl(x: Tensor2D, inverse: boolean): Tensor2D {\n const xData = this.texData.get(x.dataId);\n\n const realProgram =\n new FFTProgram(fft_gpu.COMPLEX_FFT.REAL, x.shape, inverse);\n const imagProgram =\n new FFTProgram(fft_gpu.COMPLEX_FFT.IMAG, x.shape, inverse);\n const inputs = [\n this.makeComplexComponentTensorInfo(x, xData.complexTensors.real),\n this.makeComplexComponentTensorInfo(x, xData.complexTensors.imag),\n ];\n\n const real = this.compileAndRun(realProgram, inputs);\n const imag = this.compileAndRun(imagProgram, inputs);\n const complex = this.complex(real, imag).as2D(x.shape[0], x.shape[1]);\n real.dispose();\n imag.dispose();\n return complex;\n }\n\n gatherND(x: Tensor, indices: Tensor): Tensor {\n const indicesShape = indices.shape;\n const sliceRank = indicesShape[indicesShape.length - 1];\n\n const [resultShape, numSlices, sliceSize, strides] =\n gather_nd_util.prepareAndValidate(x, indices);\n\n const flattenIndices = indices.reshape([numSlices, sliceRank]);\n const flattenX = x.reshape([x.size / sliceSize, sliceSize]);\n const program =\n new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize]);\n const res: Tensor = this.compileAndRun(program, [flattenX, flattenIndices]);\n return res.reshape(resultShape);\n }\n\n fill(\n shape: ShapeMap[R], value: number|string, dtype?: DataType): Tensor {\n dtype = dtype || inferDtype(value);\n\n if (dtype === 'string') {\n // String type should be handled in CPU memory.\n const values = getArrayFromDType(dtype, sizeFromShape(shape));\n values.fill(value as string);\n return ENGINE.makeTensor(values, shape, dtype, this) as Tensor;\n } else {\n const program = new FillProgram(shape, value as number);\n const customSetup = program.getCustomSetupFunc(value as number);\n return this.compileAndRun(program, [], dtype, customSetup);\n }\n }\n\n onesLike(x: Tensor): Tensor {\n if (x.dtype === 'string') {\n throw new Error('onesLike is not supported under string dtype');\n } else {\n // TODO(cais, smilkov): Add WebGL shader for onesLike:\n // https://github.com/tensorflow/tfjs/issues/1293\n return this.fill(x.shape, 1, x.dtype);\n }\n }\n\n zerosLike(x: Tensor): Tensor {\n return this.fill(x.shape, x.dtype === 'string' ? '' : 0, x.dtype);\n }\n\n linspace(start: number, stop: number, num: number): Tensor1D {\n // TODO: Use CPU implementation due to the precision problem in Safari.\n return backend_util.linspaceImpl(start, stop, num);\n }\n\n makeTensorInfo(shape: number[], dtype: DataType): TensorInfo {\n const dataId = this.write(null /* values */, shape, dtype);\n this.texData.get(dataId).usage = null;\n return {dataId, shape, dtype};\n }\n\n private makeOutput(shape: number[], dtype: DataType): T {\n const {dataId} = this.makeTensorInfo(shape, dtype);\n return ENGINE.makeTensorFromDataId(dataId, shape, dtype, this) as T;\n }\n\n private unpackTensor(input: TensorInfo): TensorInfo {\n const program = new UnpackProgram(input.shape);\n return this.runWebGLProgram(program, [input], input.dtype);\n }\n\n private packTensor(input: TensorInfo): TensorInfo {\n const program = new PackProgram(input.shape);\n const preventEagerUnpackingOutput = true;\n return this.runWebGLProgram(\n program, [input], input.dtype, null /* customSetup */,\n preventEagerUnpackingOutput);\n }\n\n private packedReshape(input: TensorInfo, afterShape: number[]): TensorInfo {\n const input3DShape = [\n webgl_util.getBatchDim(input.shape),\n ...webgl_util.getRowsCols(input.shape)\n ] as [number, number, number];\n const input3D: TensorInfo = {\n dtype: input.dtype,\n shape: input3DShape,\n dataId: input.dataId\n };\n const afterShapeAs3D = [\n webgl_util.getBatchDim(afterShape), ...webgl_util.getRowsCols(afterShape)\n ] as [number, number, number];\n\n const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);\n const preventEagerUnpackingOfOutput = true;\n const output = this.runWebGLProgram(\n program, [input3D], input.dtype, null /* customSetup */,\n preventEagerUnpackingOfOutput);\n return {dataId: output.dataId, shape: afterShape, dtype: output.dtype};\n }\n\n private decode(dataId: DataId): TensorInfo {\n const texData = this.texData.get(dataId);\n const {isPacked, shape, dtype} = texData;\n const shapeAs3D =\n webgl_util.getShapeAs3D(shape) as [number, number, number];\n let program;\n if (isPacked) {\n program = new DecodeMatrixPackedProgram(shapeAs3D);\n } else {\n program = new DecodeMatrixProgram(shapeAs3D);\n }\n const preventEagerUnpackingOfOutput = true;\n const out = this.runWebGLProgram(\n program, [{shape: shapeAs3D, dtype, dataId}], dtype,\n null /* customSetup */, preventEagerUnpackingOfOutput);\n return {dtype, shape, dataId: out.dataId};\n }\n\n runWebGLProgram(\n program: GPGPUProgram, inputs: TensorInfo[], outputDtype: DataType,\n customSetup?: (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => void,\n preventEagerUnpackingOfOutput = false): TensorInfo {\n const output = this.makeTensorInfo(program.outputShape, outputDtype);\n const outData = this.texData.get(output.dataId);\n if (program.packedOutput) {\n outData.isPacked = true;\n }\n if (program.outPackingScheme === tex_util.PackingScheme.DENSE) {\n const texelShape = tex_util.getDenseTexShape(program.outputShape);\n // For a densely packed output, we explicitly set texShape\n // so it doesn't get assigned later according to our typical packing\n // scheme wherein a single texel can only contain values from adjacent\n // rows/cols.\n outData.texShape = texelShape.map(d => d * 2) as [number, number];\n }\n if (program.outTexUsage != null) {\n outData.usage = program.outTexUsage;\n }\n if (sizeFromShape(output.shape) === 0) {\n // Short-circuit the computation since the result is empty (has 0 in its\n // shape).\n outData.values = getTypedArrayFromDType(output.dtype as 'float32', 0);\n return output;\n }\n\n const dataToDispose: TensorInfo[] = [];\n const inputsData: TensorData[] = inputs.map(input => {\n if (input.dtype === 'complex64') {\n throw new Error(\n `GPGPUProgram does not support complex64 input. For complex64 ` +\n `dtypes, please separate the program into real and imaginary ` +\n `parts.`);\n }\n\n let texData = this.texData.get(input.dataId);\n\n if (texData.texture == null) {\n if (!program.packedInputs &&\n util.sizeFromShape(input.shape) <=\n env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) {\n // Upload small tensors that live on the CPU as uniforms, not as\n // textures. Do this only when the environment supports 32bit floats\n // due to problems when comparing 16bit floats with 32bit floats.\n // TODO(https://github.com/tensorflow/tfjs/issues/821): Make it\n // possible for packed shaders to sample from uniforms.\n return {\n shape: input.shape,\n texData: null,\n isUniform: true,\n uniformValues: texData.values as TypedArray\n };\n }\n\n // This ensures that if a packed program's inputs have not yet been\n // uploaded to the GPU, they get uploaded as packed right off the bat.\n if (program.packedInputs) {\n texData.isPacked = true;\n texData.shape = input.shape;\n }\n } else if (!!texData.isPacked !== !!program.packedInputs) {\n input = texData.isPacked ? this.unpackTensor(input) :\n this.packTensor(input);\n dataToDispose.push(input);\n texData = this.texData.get(input.dataId);\n } else if (\n texData.isPacked &&\n !webgl_util.isReshapeFree(texData.shape, input.shape)) {\n // This is a special case where a texture exists for a tensor\n // but the shapes are incompatible (due to packing constraints) because\n // the tensor did not have a chance to go through the packed reshape\n // shader. This only happens when we reshape the *same* tensor to form\n // *distinct* inputs to an op, e.g. dotting a vector with itself. This\n // case will disappear once packed uploading is the default.\n\n const savedInput = input;\n const targetShape = input.shape;\n\n input.shape = texData.shape;\n input = this.packedReshape(input as Tensor, targetShape);\n dataToDispose.push(input);\n texData = this.texData.get(input.dataId);\n\n savedInput.shape = targetShape;\n }\n\n this.uploadToGPU(input.dataId);\n return {shape: input.shape, texData, isUniform: false};\n });\n\n this.uploadToGPU(output.dataId);\n const outputData:\n TensorData = {shape: output.shape, texData: outData, isUniform: false};\n const key = gpgpu_math.makeShaderKey(program, inputsData, outputData);\n const binary = this.getAndSaveBinary(key, () => {\n return gpgpu_math.compileProgram(\n this.gpgpu, program, inputsData, outputData);\n });\n const shouldTimeProgram = this.activeTimers != null;\n let query: WebGLQuery|CPUTimerQuery;\n if (shouldTimeProgram) {\n query = this.startTimer();\n }\n\n gpgpu_math.runProgram(\n this.gpgpu, binary, inputsData, outputData, customSetup);\n\n dataToDispose.forEach(info => this.disposeData(info.dataId));\n\n if (shouldTimeProgram) {\n query = this.endTimer(query);\n this.activeTimers.push(\n {name: program.constructor.name, query: this.getQueryTime(query)});\n }\n\n if (!env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked &&\n preventEagerUnpackingOfOutput === false) {\n const unpacked = this.unpackTensor(output);\n this.disposeData(output.dataId);\n return unpacked;\n }\n return output;\n }\n\n compileAndRun(\n program: GPGPUProgram, inputs: TensorInfo[], outputDtype?: DataType,\n customSetup?: (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => void,\n preventEagerUnpackingOfOutput = false): K {\n outputDtype = outputDtype || inputs[0].dtype;\n const outInfo = this.runWebGLProgram(\n program, inputs, outputDtype, customSetup,\n preventEagerUnpackingOfOutput);\n return ENGINE.makeTensorFromDataId(\n outInfo.dataId, outInfo.shape, outInfo.dtype) as {} as K;\n }\n\n private getAndSaveBinary(key: string, getBinary: () => GPGPUBinary):\n GPGPUBinary {\n if (!(key in this.binaryCache)) {\n this.binaryCache[key] = getBinary();\n }\n return this.binaryCache[key];\n }\n\n getTextureManager(): TextureManager {\n return this.textureManager;\n }\n\n private disposed = false;\n\n dispose() {\n if (this.disposed) {\n return;\n }\n // Avoid disposing the compiled webgl programs during unit testing because\n // it slows down test execution.\n if (!env().getBool('IS_TEST')) {\n const allKeys = Object.keys(this.binaryCache);\n allKeys.forEach(key => {\n this.gpgpu.deleteProgram(this.binaryCache[key].webGLProgram);\n delete this.binaryCache[key];\n });\n }\n this.textureManager.dispose();\n if (this.canvas != null &&\n (typeof (HTMLCanvasElement) !== 'undefined' &&\n this.canvas instanceof HTMLCanvasElement)) {\n this.canvas.remove();\n } else {\n this.canvas = null;\n }\n if (this.gpgpuCreatedLocally) {\n this.gpgpu.program = null;\n this.gpgpu.dispose();\n }\n this.disposed = true;\n }\n\n floatPrecision(): 16|32 {\n if (this.floatPrecisionValue == null) {\n this.floatPrecisionValue = tidy(() => {\n if (!env().get('WEBGL_RENDER_FLOAT32_ENABLED')) {\n // Momentarily switching DEBUG flag to false so we don't throw an\n // error trying to upload a small value.\n const debugFlag = env().getBool('DEBUG');\n env().set('DEBUG', false);\n const underflowCheckValue = this.abs(scalar(1e-8)).dataSync()[0];\n env().set('DEBUG', debugFlag);\n\n if (underflowCheckValue > 0) {\n return 32;\n }\n }\n return 16;\n });\n }\n return this.floatPrecisionValue;\n }\n /** Returns the smallest representable number. */\n epsilon(): number {\n return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;\n }\n\n private uploadToGPU(dataId: DataId): void {\n const texData = this.texData.get(dataId);\n const {shape, dtype, values, texture, usage, isPacked} = texData;\n\n if (texture != null) {\n // Array is already on GPU. No-op.\n return;\n }\n const shouldTimeProgram = this.activeTimers != null;\n let start: number;\n if (shouldTimeProgram) {\n start = util.now();\n }\n\n let texShape = texData.texShape;\n if (texShape == null) {\n texShape = webgl_util.getTextureShapeFromLogicalShape(shape, isPacked);\n texData.texShape = texShape;\n }\n\n if (values != null) {\n const shapeAs3D = webgl_util.getShapeAs3D(shape);\n\n let program;\n let width = texShape[1], height = texShape[0];\n const isByteArray = values instanceof Uint8Array;\n\n if (isPacked) {\n [width, height] = tex_util.getPackedMatrixTextureShapeWidthHeight(\n texShape[0], texShape[1]);\n program = new EncodeMatrixPackedProgram(\n shapeAs3D, [height, width], isByteArray);\n } else {\n program =\n new EncodeMatrixProgram(shapeAs3D, [height, width], isByteArray);\n }\n\n const tempDenseInputHandle = this.makeTensorInfo([height, width], dtype);\n if (isByteArray) {\n this.texData.get(tempDenseInputHandle.dataId).usage =\n TextureUsage.PIXELS;\n } else {\n this.texData.get(tempDenseInputHandle.dataId).usage =\n TextureUsage.UPLOAD;\n }\n this.gpgpu.uploadDenseMatrixToTexture(\n this.getTexture(tempDenseInputHandle.dataId), width, height,\n values as TypedArray);\n\n // We want the output to remain packed regardless of the value of\n // WEBGL_PACK.\n const preventEagerUnpacking = true;\n const encodedOutputTarget = this.runWebGLProgram(\n program, [tempDenseInputHandle], dtype, null, preventEagerUnpacking);\n\n // Have the original texture assume the identity of the encoded output.\n const outputTexData = this.texData.get(encodedOutputTarget.dataId);\n texData.texture = outputTexData.texture;\n texData.texShape = outputTexData.texShape;\n texData.isPacked = outputTexData.isPacked;\n texData.usage = outputTexData.usage;\n\n this.disposeData(tempDenseInputHandle.dataId);\n this.texData.delete(encodedOutputTarget.dataId);\n\n // Once uploaded, don't store the values on cpu.\n texData.values = null;\n if (shouldTimeProgram) {\n this.uploadWaitMs += util.now() - start;\n }\n } else {\n const newTexture = this.acquireTexture(texShape, usage, dtype, isPacked);\n texData.texture = newTexture;\n }\n }\n\n private convertAndCacheOnCPU(dataId: DataId, float32Values?: Float32Array):\n TypedArray {\n const texData = this.texData.get(dataId);\n const {dtype} = texData;\n\n this.releaseGPUData(dataId);\n\n if (float32Values != null) {\n texData.values = float32ToTypedArray(float32Values, dtype as 'float32');\n }\n return texData.values as TypedArray;\n }\n\n private acquireTexture(\n texShape: [number, number], texType: TextureUsage, dtype: DataType,\n isPacked: boolean): WebGLTexture {\n this.numBytesInGPU += this.computeBytes(texShape, dtype);\n if (!this.warnedAboutMemory &&\n this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024) {\n const mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2);\n this.warnedAboutMemory = true;\n console.warn(\n `High memory usage in GPU: ${mb} MB, ` +\n `most likely due to a memory leak`);\n }\n return this.textureManager.acquireTexture(texShape, texType, isPacked);\n }\n\n private computeBytes(shape: [number, number], dtype: DataType) {\n return shape[0] * shape[1] * util.bytesPerElement(dtype);\n }\n}\n\nif (device_util.isBrowser()) {\n ENGINE.registerBackend(\n 'webgl', () => new MathBackendWebGL(), 2 /* priority */);\n}\n\nfunction float32ToTypedArray(\n a: Float32Array, dtype: D): DataTypeMap[D] {\n if (dtype === 'float32' || dtype === 'complex64') {\n return a as DataTypeMap[D];\n } else if (dtype === 'int32' || dtype === 'bool') {\n const result = (dtype === 'int32') ? new Int32Array(a.length) :\n new Uint8Array(a.length);\n for (let i = 0; i < result.length; ++i) {\n result[i] = Math.round(a[i]);\n }\n return result as DataTypeMap[D];\n } else {\n throw new Error(`Unknown dtype ${dtype}`);\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\nimport {nearestDivisor} from '../util';\n\nimport {PARALLELIZE_THRESHOLD} from './reduce_util';\n\nexport interface SegOpInfo {\n windowSize: number;\n batchSize: number;\n inSize: number;\n numSegments: number;\n}\n\nexport function segOpComputeOptimalWindowSize(\n inSize: number, numSegments: number): number {\n let done = false;\n let res;\n\n if (inSize <= PARALLELIZE_THRESHOLD) {\n res = inSize;\n done = true;\n } else {\n res = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));\n }\n\n while (!done) {\n if (res > numSegments || res === inSize) {\n done = true;\n } else {\n res = nearestDivisor(inSize, res + 1);\n }\n }\n return res;\n}\n\nexport function computeOutShape(\n aShape: number[], axis: number, numSegments: number): number[] {\n const outShape = [];\n const rank = aShape.length;\n for (let dim = 0; dim < rank; dim++) {\n if (dim !== axis) {\n outShape.push(aShape[dim]);\n } else {\n outShape.push(numSegments);\n }\n }\n return outShape;\n}\n\nexport interface GatherOpShapeInfo {\n batchSize: number;\n sliceSize: number;\n dimSize: number;\n outputShape: number[];\n}\nexport function collectGatherOpShapeInfo(\n x: Tensor, indices: Tensor, axis: number): GatherOpShapeInfo {\n const dimSize = x.shape[axis];\n\n const outputShape: number[] = [];\n let batchSize = 1;\n let sliceSize = 1;\n for (let i = 0; i < axis; i++) {\n outputShape.push(x.shape[i]);\n batchSize *= x.shape[i];\n }\n\n for (let i = 0; i < indices.rank; i++) {\n outputShape.push(indices.shape[i]);\n }\n\n for (let i = axis + 1; i < x.rank; i++) {\n outputShape.push(x.shape[i]);\n sliceSize *= x.shape[i];\n }\n\n return {batchSize, sliceSize, dimSize, outputShape};\n}\n","// A port of an algorithm by Johannes Baagøe , 2010\n// http://baagoe.com/en/RandomMusings/javascript/\n// https://github.com/nquinlan/better-random-numbers-for-javascript-mirror\n// Original work is under MIT license -\n\n// Copyright (C) 2010 by Johannes Baagøe \n//\n// Permission is hereby granted, free of charge, to any person obtaining a copy\n// of this software and associated documentation files (the \"Software\"), to deal\n// in the Software without restriction, including without limitation the rights\n// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n// copies of the Software, and to permit persons to whom the Software is\n// furnished to do so, subject to the following conditions:\n// \n// The above copyright notice and this permission notice shall be included in\n// all copies or substantial portions of the Software.\n// \n// THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n// THE SOFTWARE.\n\n\n\n(function(global, module, define) {\n\nfunction Alea(seed) {\n var me = this, mash = Mash();\n\n me.next = function() {\n var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10; // 2^-32\n me.s0 = me.s1;\n me.s1 = me.s2;\n return me.s2 = t - (me.c = t | 0);\n };\n\n // Apply the seeding algorithm from Baagoe.\n me.c = 1;\n me.s0 = mash(' ');\n me.s1 = mash(' ');\n me.s2 = mash(' ');\n me.s0 -= mash(seed);\n if (me.s0 < 0) { me.s0 += 1; }\n me.s1 -= mash(seed);\n if (me.s1 < 0) { me.s1 += 1; }\n me.s2 -= mash(seed);\n if (me.s2 < 0) { me.s2 += 1; }\n mash = null;\n}\n\nfunction copy(f, t) {\n t.c = f.c;\n t.s0 = f.s0;\n t.s1 = f.s1;\n t.s2 = f.s2;\n return t;\n}\n\nfunction impl(seed, opts) {\n var xg = new Alea(seed),\n state = opts && opts.state,\n prng = xg.next;\n prng.int32 = function() { return (xg.next() * 0x100000000) | 0; }\n prng.double = function() {\n return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16; // 2^-53\n };\n prng.quick = prng;\n if (state) {\n if (typeof(state) == 'object') copy(state, xg);\n prng.state = function() { return copy(xg, {}); }\n }\n return prng;\n}\n\nfunction Mash() {\n var n = 0xefc8249d;\n\n var mash = function(data) {\n data = data.toString();\n for (var i = 0; i < data.length; i++) {\n n += data.charCodeAt(i);\n var h = 0.02519603282416938 * n;\n n = h >>> 0;\n h -= n;\n h *= n;\n n = h >>> 0;\n h -= n;\n n += h * 0x100000000; // 2^32\n }\n return (n >>> 0) * 2.3283064365386963e-10; // 2^-32\n };\n\n return mash;\n}\n\n\nif (module && module.exports) {\n module.exports = impl;\n} else if (define && define.amd) {\n define(function() { return impl; });\n} else {\n this.alea = impl;\n}\n\n})(\n this,\n (typeof module) == 'object' && module, // present in node.js\n (typeof define) == 'function' && define // present with an AMD loader\n);\n\n\n","// A Javascript implementaion of the \"xor128\" prng algorithm by\n// George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper\n\n(function(global, module, define) {\n\nfunction XorGen(seed) {\n var me = this, strseed = '';\n\n me.x = 0;\n me.y = 0;\n me.z = 0;\n me.w = 0;\n\n // Set up generator function.\n me.next = function() {\n var t = me.x ^ (me.x << 11);\n me.x = me.y;\n me.y = me.z;\n me.z = me.w;\n return me.w ^= (me.w >>> 19) ^ t ^ (t >>> 8);\n };\n\n if (seed === (seed | 0)) {\n // Integer seed.\n me.x = seed;\n } else {\n // String seed.\n strseed += seed;\n }\n\n // Mix in string seed, then discard an initial batch of 64 values.\n for (var k = 0; k < strseed.length + 64; k++) {\n me.x ^= strseed.charCodeAt(k) | 0;\n me.next();\n }\n}\n\nfunction copy(f, t) {\n t.x = f.x;\n t.y = f.y;\n t.z = f.z;\n t.w = f.w;\n return t;\n}\n\nfunction impl(seed, opts) {\n var xg = new XorGen(seed),\n state = opts && opts.state,\n prng = function() { return (xg.next() >>> 0) / 0x100000000; };\n prng.double = function() {\n do {\n var top = xg.next() >>> 11,\n bot = (xg.next() >>> 0) / 0x100000000,\n result = (top + bot) / (1 << 21);\n } while (result === 0);\n return result;\n };\n prng.int32 = xg.next;\n prng.quick = prng;\n if (state) {\n if (typeof(state) == 'object') copy(state, xg);\n prng.state = function() { return copy(xg, {}); }\n }\n return prng;\n}\n\nif (module && module.exports) {\n module.exports = impl;\n} else if (define && define.amd) {\n define(function() { return impl; });\n} else {\n this.xor128 = impl;\n}\n\n})(\n this,\n (typeof module) == 'object' && module, // present in node.js\n (typeof define) == 'function' && define // present with an AMD loader\n);\n\n\n","// A Javascript implementaion of the \"xorwow\" prng algorithm by\n// George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper\n\n(function(global, module, define) {\n\nfunction XorGen(seed) {\n var me = this, strseed = '';\n\n // Set up generator function.\n me.next = function() {\n var t = (me.x ^ (me.x >>> 2));\n me.x = me.y; me.y = me.z; me.z = me.w; me.w = me.v;\n return (me.d = (me.d + 362437 | 0)) +\n (me.v = (me.v ^ (me.v << 4)) ^ (t ^ (t << 1))) | 0;\n };\n\n me.x = 0;\n me.y = 0;\n me.z = 0;\n me.w = 0;\n me.v = 0;\n\n if (seed === (seed | 0)) {\n // Integer seed.\n me.x = seed;\n } else {\n // String seed.\n strseed += seed;\n }\n\n // Mix in string seed, then discard an initial batch of 64 values.\n for (var k = 0; k < strseed.length + 64; k++) {\n me.x ^= strseed.charCodeAt(k) | 0;\n if (k == strseed.length) {\n me.d = me.x << 10 ^ me.x >>> 4;\n }\n me.next();\n }\n}\n\nfunction copy(f, t) {\n t.x = f.x;\n t.y = f.y;\n t.z = f.z;\n t.w = f.w;\n t.v = f.v;\n t.d = f.d;\n return t;\n}\n\nfunction impl(seed, opts) {\n var xg = new XorGen(seed),\n state = opts && opts.state,\n prng = function() { return (xg.next() >>> 0) / 0x100000000; };\n prng.double = function() {\n do {\n var top = xg.next() >>> 11,\n bot = (xg.next() >>> 0) / 0x100000000,\n result = (top + bot) / (1 << 21);\n } while (result === 0);\n return result;\n };\n prng.int32 = xg.next;\n prng.quick = prng;\n if (state) {\n if (typeof(state) == 'object') copy(state, xg);\n prng.state = function() { return copy(xg, {}); }\n }\n return prng;\n}\n\nif (module && module.exports) {\n module.exports = impl;\n} else if (define && define.amd) {\n define(function() { return impl; });\n} else {\n this.xorwow = impl;\n}\n\n})(\n this,\n (typeof module) == 'object' && module, // present in node.js\n (typeof define) == 'function' && define // present with an AMD loader\n);\n\n\n","// A Javascript implementaion of the \"xorshift7\" algorithm by\n// François Panneton and Pierre L'ecuyer:\n// \"On the Xorgshift Random Number Generators\"\n// http://saluc.engr.uconn.edu/refs/crypto/rng/panneton05onthexorshift.pdf\n\n(function(global, module, define) {\n\nfunction XorGen(seed) {\n var me = this;\n\n // Set up generator function.\n me.next = function() {\n // Update xor generator.\n var X = me.x, i = me.i, t, v, w;\n t = X[i]; t ^= (t >>> 7); v = t ^ (t << 24);\n t = X[(i + 1) & 7]; v ^= t ^ (t >>> 10);\n t = X[(i + 3) & 7]; v ^= t ^ (t >>> 3);\n t = X[(i + 4) & 7]; v ^= t ^ (t << 7);\n t = X[(i + 7) & 7]; t = t ^ (t << 13); v ^= t ^ (t << 9);\n X[i] = v;\n me.i = (i + 1) & 7;\n return v;\n };\n\n function init(me, seed) {\n var j, w, X = [];\n\n if (seed === (seed | 0)) {\n // Seed state array using a 32-bit integer.\n w = X[0] = seed;\n } else {\n // Seed state using a string.\n seed = '' + seed;\n for (j = 0; j < seed.length; ++j) {\n X[j & 7] = (X[j & 7] << 15) ^\n (seed.charCodeAt(j) + X[(j + 1) & 7] << 13);\n }\n }\n // Enforce an array length of 8, not all zeroes.\n while (X.length < 8) X.push(0);\n for (j = 0; j < 8 && X[j] === 0; ++j);\n if (j == 8) w = X[7] = -1; else w = X[j];\n\n me.x = X;\n me.i = 0;\n\n // Discard an initial 256 values.\n for (j = 256; j > 0; --j) {\n me.next();\n }\n }\n\n init(me, seed);\n}\n\nfunction copy(f, t) {\n t.x = f.x.slice();\n t.i = f.i;\n return t;\n}\n\nfunction impl(seed, opts) {\n if (seed == null) seed = +(new Date);\n var xg = new XorGen(seed),\n state = opts && opts.state,\n prng = function() { return (xg.next() >>> 0) / 0x100000000; };\n prng.double = function() {\n do {\n var top = xg.next() >>> 11,\n bot = (xg.next() >>> 0) / 0x100000000,\n result = (top + bot) / (1 << 21);\n } while (result === 0);\n return result;\n };\n prng.int32 = xg.next;\n prng.quick = prng;\n if (state) {\n if (state.x) copy(state, xg);\n prng.state = function() { return copy(xg, {}); }\n }\n return prng;\n}\n\nif (module && module.exports) {\n module.exports = impl;\n} else if (define && define.amd) {\n define(function() { return impl; });\n} else {\n this.xorshift7 = impl;\n}\n\n})(\n this,\n (typeof module) == 'object' && module, // present in node.js\n (typeof define) == 'function' && define // present with an AMD loader\n);\n\n","// A Javascript implementaion of Richard Brent's Xorgens xor4096 algorithm.\n//\n// This fast non-cryptographic random number generator is designed for\n// use in Monte-Carlo algorithms. It combines a long-period xorshift\n// generator with a Weyl generator, and it passes all common batteries\n// of stasticial tests for randomness while consuming only a few nanoseconds\n// for each prng generated. For background on the generator, see Brent's\n// paper: \"Some long-period random number generators using shifts and xors.\"\n// http://arxiv.org/pdf/1004.3115v1.pdf\n//\n// Usage:\n//\n// var xor4096 = require('xor4096');\n// random = xor4096(1); // Seed with int32 or string.\n// assert.equal(random(), 0.1520436450538547); // (0, 1) range, 53 bits.\n// assert.equal(random.int32(), 1806534897); // signed int32, 32 bits.\n//\n// For nonzero numeric keys, this impelementation provides a sequence\n// identical to that by Brent's xorgens 3 implementaion in C. This\n// implementation also provides for initalizing the generator with\n// string seeds, or for saving and restoring the state of the generator.\n//\n// On Chrome, this prng benchmarks about 2.1 times slower than\n// Javascript's built-in Math.random().\n\n(function(global, module, define) {\n\nfunction XorGen(seed) {\n var me = this;\n\n // Set up generator function.\n me.next = function() {\n var w = me.w,\n X = me.X, i = me.i, t, v;\n // Update Weyl generator.\n me.w = w = (w + 0x61c88647) | 0;\n // Update xor generator.\n v = X[(i + 34) & 127];\n t = X[i = ((i + 1) & 127)];\n v ^= v << 13;\n t ^= t << 17;\n v ^= v >>> 15;\n t ^= t >>> 12;\n // Update Xor generator array state.\n v = X[i] = v ^ t;\n me.i = i;\n // Result is the combination.\n return (v + (w ^ (w >>> 16))) | 0;\n };\n\n function init(me, seed) {\n var t, v, i, j, w, X = [], limit = 128;\n if (seed === (seed | 0)) {\n // Numeric seeds initialize v, which is used to generates X.\n v = seed;\n seed = null;\n } else {\n // String seeds are mixed into v and X one character at a time.\n seed = seed + '\\0';\n v = 0;\n limit = Math.max(limit, seed.length);\n }\n // Initialize circular array and weyl value.\n for (i = 0, j = -32; j < limit; ++j) {\n // Put the unicode characters into the array, and shuffle them.\n if (seed) v ^= seed.charCodeAt((j + 32) % seed.length);\n // After 32 shuffles, take v as the starting w value.\n if (j === 0) w = v;\n v ^= v << 10;\n v ^= v >>> 15;\n v ^= v << 4;\n v ^= v >>> 13;\n if (j >= 0) {\n w = (w + 0x61c88647) | 0; // Weyl.\n t = (X[j & 127] ^= (v + w)); // Combine xor and weyl to init array.\n i = (0 == t) ? i + 1 : 0; // Count zeroes.\n }\n }\n // We have detected all zeroes; make the key nonzero.\n if (i >= 128) {\n X[(seed && seed.length || 0) & 127] = -1;\n }\n // Run the generator 512 times to further mix the state before using it.\n // Factoring this as a function slows the main generator, so it is just\n // unrolled here. The weyl generator is not advanced while warming up.\n i = 127;\n for (j = 4 * 128; j > 0; --j) {\n v = X[(i + 34) & 127];\n t = X[i = ((i + 1) & 127)];\n v ^= v << 13;\n t ^= t << 17;\n v ^= v >>> 15;\n t ^= t >>> 12;\n X[i] = v ^ t;\n }\n // Storing state as object members is faster than using closure variables.\n me.w = w;\n me.X = X;\n me.i = i;\n }\n\n init(me, seed);\n}\n\nfunction copy(f, t) {\n t.i = f.i;\n t.w = f.w;\n t.X = f.X.slice();\n return t;\n};\n\nfunction impl(seed, opts) {\n if (seed == null) seed = +(new Date);\n var xg = new XorGen(seed),\n state = opts && opts.state,\n prng = function() { return (xg.next() >>> 0) / 0x100000000; };\n prng.double = function() {\n do {\n var top = xg.next() >>> 11,\n bot = (xg.next() >>> 0) / 0x100000000,\n result = (top + bot) / (1 << 21);\n } while (result === 0);\n return result;\n };\n prng.int32 = xg.next;\n prng.quick = prng;\n if (state) {\n if (state.X) copy(state, xg);\n prng.state = function() { return copy(xg, {}); }\n }\n return prng;\n}\n\nif (module && module.exports) {\n module.exports = impl;\n} else if (define && define.amd) {\n define(function() { return impl; });\n} else {\n this.xor4096 = impl;\n}\n\n})(\n this, // window object or global\n (typeof module) == 'object' && module, // present in node.js\n (typeof define) == 'function' && define // present with an AMD loader\n);\n","// A Javascript implementaion of the \"Tyche-i\" prng algorithm by\n// Samuel Neves and Filipe Araujo.\n// See https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf\n\n(function(global, module, define) {\n\nfunction XorGen(seed) {\n var me = this, strseed = '';\n\n // Set up generator function.\n me.next = function() {\n var b = me.b, c = me.c, d = me.d, a = me.a;\n b = (b << 25) ^ (b >>> 7) ^ c;\n c = (c - d) | 0;\n d = (d << 24) ^ (d >>> 8) ^ a;\n a = (a - b) | 0;\n me.b = b = (b << 20) ^ (b >>> 12) ^ c;\n me.c = c = (c - d) | 0;\n me.d = (d << 16) ^ (c >>> 16) ^ a;\n return me.a = (a - b) | 0;\n };\n\n /* The following is non-inverted tyche, which has better internal\n * bit diffusion, but which is about 25% slower than tyche-i in JS.\n me.next = function() {\n var a = me.a, b = me.b, c = me.c, d = me.d;\n a = (me.a + me.b | 0) >>> 0;\n d = me.d ^ a; d = d << 16 ^ d >>> 16;\n c = me.c + d | 0;\n b = me.b ^ c; b = b << 12 ^ d >>> 20;\n me.a = a = a + b | 0;\n d = d ^ a; me.d = d = d << 8 ^ d >>> 24;\n me.c = c = c + d | 0;\n b = b ^ c;\n return me.b = (b << 7 ^ b >>> 25);\n }\n */\n\n me.a = 0;\n me.b = 0;\n me.c = 2654435769 | 0;\n me.d = 1367130551;\n\n if (seed === Math.floor(seed)) {\n // Integer seed.\n me.a = (seed / 0x100000000) | 0;\n me.b = seed | 0;\n } else {\n // String seed.\n strseed += seed;\n }\n\n // Mix in string seed, then discard an initial batch of 64 values.\n for (var k = 0; k < strseed.length + 20; k++) {\n me.b ^= strseed.charCodeAt(k) | 0;\n me.next();\n }\n}\n\nfunction copy(f, t) {\n t.a = f.a;\n t.b = f.b;\n t.c = f.c;\n t.d = f.d;\n return t;\n};\n\nfunction impl(seed, opts) {\n var xg = new XorGen(seed),\n state = opts && opts.state,\n prng = function() { return (xg.next() >>> 0) / 0x100000000; };\n prng.double = function() {\n do {\n var top = xg.next() >>> 11,\n bot = (xg.next() >>> 0) / 0x100000000,\n result = (top + bot) / (1 << 21);\n } while (result === 0);\n return result;\n };\n prng.int32 = xg.next;\n prng.quick = prng;\n if (state) {\n if (typeof(state) == 'object') copy(state, xg);\n prng.state = function() { return copy(xg, {}); }\n }\n return prng;\n}\n\nif (module && module.exports) {\n module.exports = impl;\n} else if (define && define.amd) {\n define(function() { return impl; });\n} else {\n this.tychei = impl;\n}\n\n})(\n this,\n (typeof module) == 'object' && module, // present in node.js\n (typeof define) == 'function' && define // present with an AMD loader\n);\n\n\n","/*\nCopyright 2014 David Bau.\n\nPermission is hereby granted, free of charge, to any person obtaining\na copy of this software and associated documentation files (the\n\"Software\"), to deal in the Software without restriction, including\nwithout limitation the rights to use, copy, modify, merge, publish,\ndistribute, sublicense, and/or sell copies of the Software, and to\npermit persons to whom the Software is furnished to do so, subject to\nthe following conditions:\n\nThe above copyright notice and this permission notice shall be\nincluded in all copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\nEXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\nMERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.\nIN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY\nCLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,\nTORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE\nSOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n\n*/\n\n(function (pool, math) {\n//\n// The following constants are related to IEEE 754 limits.\n//\nvar global = this,\n width = 256, // each RC4 output is 0 <= x < 256\n chunks = 6, // at least six RC4 outputs for each double\n digits = 52, // there are 52 significant digits in a double\n rngname = 'random', // rngname: name for Math.random and Math.seedrandom\n startdenom = math.pow(width, chunks),\n significance = math.pow(2, digits),\n overflow = significance * 2,\n mask = width - 1,\n nodecrypto; // node.js crypto module, initialized at the bottom.\n\n//\n// seedrandom()\n// This is the seedrandom function described above.\n//\nfunction seedrandom(seed, options, callback) {\n var key = [];\n options = (options == true) ? { entropy: true } : (options || {});\n\n // Flatten the seed string or build one from local entropy if needed.\n var shortseed = mixkey(flatten(\n options.entropy ? [seed, tostring(pool)] :\n (seed == null) ? autoseed() : seed, 3), key);\n\n // Use the seed to initialize an ARC4 generator.\n var arc4 = new ARC4(key);\n\n // This function returns a random double in [0, 1) that contains\n // randomness in every bit of the mantissa of the IEEE 754 value.\n var prng = function() {\n var n = arc4.g(chunks), // Start with a numerator n < 2 ^ 48\n d = startdenom, // and denominator d = 2 ^ 48.\n x = 0; // and no 'extra last byte'.\n while (n < significance) { // Fill up all significant digits by\n n = (n + x) * width; // shifting numerator and\n d *= width; // denominator and generating a\n x = arc4.g(1); // new least-significant-byte.\n }\n while (n >= overflow) { // To avoid rounding up, before adding\n n /= 2; // last byte, shift everything\n d /= 2; // right using integer math until\n x >>>= 1; // we have exactly the desired bits.\n }\n return (n + x) / d; // Form the number within [0, 1).\n };\n\n prng.int32 = function() { return arc4.g(4) | 0; }\n prng.quick = function() { return arc4.g(4) / 0x100000000; }\n prng.double = prng;\n\n // Mix the randomness into accumulated entropy.\n mixkey(tostring(arc4.S), pool);\n\n // Calling convention: what to return as a function of prng, seed, is_math.\n return (options.pass || callback ||\n function(prng, seed, is_math_call, state) {\n if (state) {\n // Load the arc4 state from the given state if it has an S array.\n if (state.S) { copy(state, arc4); }\n // Only provide the .state method if requested via options.state.\n prng.state = function() { return copy(arc4, {}); }\n }\n\n // If called as a method of Math (Math.seedrandom()), mutate\n // Math.random because that is how seedrandom.js has worked since v1.0.\n if (is_math_call) { math[rngname] = prng; return seed; }\n\n // Otherwise, it is a newer calling convention, so return the\n // prng directly.\n else return prng;\n })(\n prng,\n shortseed,\n 'global' in options ? options.global : (this == math),\n options.state);\n}\nmath['seed' + rngname] = seedrandom;\n\n//\n// ARC4\n//\n// An ARC4 implementation. The constructor takes a key in the form of\n// an array of at most (width) integers that should be 0 <= x < (width).\n//\n// The g(count) method returns a pseudorandom integer that concatenates\n// the next (count) outputs from ARC4. Its return value is a number x\n// that is in the range 0 <= x < (width ^ count).\n//\nfunction ARC4(key) {\n var t, keylen = key.length,\n me = this, i = 0, j = me.i = me.j = 0, s = me.S = [];\n\n // The empty key [] is treated as [0].\n if (!keylen) { key = [keylen++]; }\n\n // Set up S using the standard key scheduling algorithm.\n while (i < width) {\n s[i] = i++;\n }\n for (i = 0; i < width; i++) {\n s[i] = s[j = mask & (j + key[i % keylen] + (t = s[i]))];\n s[j] = t;\n }\n\n // The \"g\" method returns the next (count) outputs as one number.\n (me.g = function(count) {\n // Using instance members instead of closure state nearly doubles speed.\n var t, r = 0,\n i = me.i, j = me.j, s = me.S;\n while (count--) {\n t = s[i = mask & (i + 1)];\n r = r * width + s[mask & ((s[i] = s[j = mask & (j + t)]) + (s[j] = t))];\n }\n me.i = i; me.j = j;\n return r;\n // For robust unpredictability, the function call below automatically\n // discards an initial batch of values. This is called RC4-drop[256].\n // See http://google.com/search?q=rsa+fluhrer+response&btnI\n })(width);\n}\n\n//\n// copy()\n// Copies internal state of ARC4 to or from a plain object.\n//\nfunction copy(f, t) {\n t.i = f.i;\n t.j = f.j;\n t.S = f.S.slice();\n return t;\n};\n\n//\n// flatten()\n// Converts an object tree to nested arrays of strings.\n//\nfunction flatten(obj, depth) {\n var result = [], typ = (typeof obj), prop;\n if (depth && typ == 'object') {\n for (prop in obj) {\n try { result.push(flatten(obj[prop], depth - 1)); } catch (e) {}\n }\n }\n return (result.length ? result : typ == 'string' ? obj : obj + '\\0');\n}\n\n//\n// mixkey()\n// Mixes a string seed into a key that is an array of integers, and\n// returns a shortened string seed that is equivalent to the result key.\n//\nfunction mixkey(seed, key) {\n var stringseed = seed + '', smear, j = 0;\n while (j < stringseed.length) {\n key[mask & j] =\n mask & ((smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++));\n }\n return tostring(key);\n}\n\n//\n// autoseed()\n// Returns an object for autoseeding, using window.crypto and Node crypto\n// module if available.\n//\nfunction autoseed() {\n try {\n var out;\n if (nodecrypto && (out = nodecrypto.randomBytes)) {\n // The use of 'out' to remember randomBytes makes tight minified code.\n out = out(width);\n } else {\n out = new Uint8Array(width);\n (global.crypto || global.msCrypto).getRandomValues(out);\n }\n return tostring(out);\n } catch (e) {\n var browser = global.navigator,\n plugins = browser && browser.plugins;\n return [+new Date, global, plugins, global.screen, tostring(pool)];\n }\n}\n\n//\n// tostring()\n// Converts an array of charcodes to a string\n//\nfunction tostring(a) {\n return String.fromCharCode.apply(0, a);\n}\n\n//\n// When seedrandom.js is loaded, we immediately mix a few bits\n// from the built-in RNG into the entropy pool. Because we do\n// not want to interfere with deterministic PRNG state later,\n// seedrandom will not call math.random on its own again after\n// initialization.\n//\nmixkey(math.random(), pool);\n\n//\n// Nodejs and AMD support: export the implementation as a module using\n// either convention.\n//\nif ((typeof module) == 'object' && module.exports) {\n module.exports = seedrandom;\n // When in node.js, try using crypto package for autoseeding.\n try {\n nodecrypto = require('crypto');\n } catch (ex) {}\n} else if ((typeof define) == 'function' && define.amd) {\n define(function() { return seedrandom; });\n}\n\n// End anonymous scope, and pass initial values.\n})(\n [], // pool: entropy pool starts empty\n Math // math: package containing random, pow, and seedrandom\n);\n","// A library of seedable RNGs implemented in Javascript.\n//\n// Usage:\n//\n// var seedrandom = require('seedrandom');\n// var random = seedrandom(1); // or any seed.\n// var x = random(); // 0 <= x < 1. Every bit is random.\n// var x = random.quick(); // 0 <= x < 1. 32 bits of randomness.\n\n// alea, a 53-bit multiply-with-carry generator by Johannes Baagøe.\n// Period: ~2^116\n// Reported to pass all BigCrush tests.\nvar alea = require('./lib/alea');\n\n// xor128, a pure xor-shift generator by George Marsaglia.\n// Period: 2^128-1.\n// Reported to fail: MatrixRank and LinearComp.\nvar xor128 = require('./lib/xor128');\n\n// xorwow, George Marsaglia's 160-bit xor-shift combined plus weyl.\n// Period: 2^192-2^32\n// Reported to fail: CollisionOver, SimpPoker, and LinearComp.\nvar xorwow = require('./lib/xorwow');\n\n// xorshift7, by François Panneton and Pierre L'ecuyer, takes\n// a different approach: it adds robustness by allowing more shifts\n// than Marsaglia's original three. It is a 7-shift generator\n// with 256 bits, that passes BigCrush with no systmatic failures.\n// Period 2^256-1.\n// No systematic BigCrush failures reported.\nvar xorshift7 = require('./lib/xorshift7');\n\n// xor4096, by Richard Brent, is a 4096-bit xor-shift with a\n// very long period that also adds a Weyl generator. It also passes\n// BigCrush with no systematic failures. Its long period may\n// be useful if you have many generators and need to avoid\n// collisions.\n// Period: 2^4128-2^32.\n// No systematic BigCrush failures reported.\nvar xor4096 = require('./lib/xor4096');\n\n// Tyche-i, by Samuel Neves and Filipe Araujo, is a bit-shifting random\n// number generator derived from ChaCha, a modern stream cipher.\n// https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf\n// Period: ~2^127\n// No systematic BigCrush failures reported.\nvar tychei = require('./lib/tychei');\n\n// The original ARC4-based prng included in this library.\n// Period: ~2^1600\nvar sr = require('./seedrandom');\n\nsr.alea = alea;\nsr.xor128 = xor128;\nsr.xorwow = xorwow;\nsr.xorshift7 = xorshift7;\nsr.xor4096 = xor4096;\nsr.tychei = tychei;\n\nmodule.exports = sr;\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {ENGINE, ForwardFunc} from '../engine';\nimport {AddNInputs} from '../kernel_names';\nimport {Tensor} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {op} from './operation';\n\n/**\n * Adds a list of `tf.Tensor`s element-wise, each with the same shape and dtype.\n *\n * ```js\n * const a = tf.tensor1d([1, 2]);\n * const b = tf.tensor1d([3, 4]);\n * const c = tf.tensor1d([5, 6]);\n *\n * tf.addN([a, b, c]).print();\n * ```\n * @param tensors A list of tensors with the same shape and dtype.\n */\n/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */\nfunction addN_(tensors: Array): T {\n util.assert(\n Array.isArray(tensors),\n () => 'The argument passed to tf.addN() must be a list of tensors');\n util.assert(\n tensors.length >= 1,\n () => `Must pass at least one tensor to tf.addN(), but got ` +\n `${tensors.length}`);\n\n const $tensors =\n tensors.map((t, i) => convertToTensor(t, `tensors${i}`, 'addN'));\n\n const firstTensor = $tensors[0];\n $tensors.forEach(t => {\n if (t.dtype !== firstTensor.dtype) {\n throw new Error(\n 'All tensors passed to tf.addN() must have the same dtype');\n }\n });\n\n $tensors.forEach(t => {\n if (!util.arraysEqual(t.shape, firstTensor.shape)) {\n throw new Error(\n 'All tensors passed to tf.addN() must have the same shape');\n }\n });\n\n const forward: ForwardFunc = (backend, save) =>\n backend.addN($tensors);\n\n const inputs: AddNInputs = $tensors;\n\n return ENGINE.runKernelFunc(\n forward, inputs as {} as NamedTensorMap, null /* grad */,\n 'AddN') as T;\n}\n\nexport const addN = op({addN_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {deprecationWarn} from '../globals';\nimport {Tensor, Tensor4D} from '../tensor';\nimport {Rank} from '../types';\n\nexport function warnDeprecation(): void {\n deprecationWarn(\n 'tf.batchNormalization() is going away. ' +\n 'Use tf.batchNorm() instead, and note the positional argument change ' +\n 'of scale, offset, and varianceEpsilon');\n}\n\nexport function xAs4D(x: Tensor) {\n let x4D: Tensor4D;\n if (x.rank === 0 || x.rank === 1) {\n x4D = x.as4D(1, 1, 1, x.size);\n } else if (x.rank === 2) {\n x4D = x.as4D(1, 1, x.shape[0], x.shape[1]);\n } else if (x.rank === 3) {\n x4D = x.as4D(1, x.shape[0], x.shape[1], x.shape[2]);\n } else {\n x4D = x as Tensor4D;\n }\n\n return x4D;\n}\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE, ForwardFunc} from '../engine';\nimport {FusedBatchNormAttrs, FusedBatchNormInputs} from '../kernel_names';\nimport {NamedAttrMap} from '../kernel_registry';\nimport {Tensor, Tensor1D, Tensor4D} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {Rank, TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {warnDeprecation, xAs4D} from './batchnorm_util';\nimport {op} from './operation';\n\n/**\n * @deprecated Please use `tf.batchNorm` instead and note the positional\n * argument change of scale, offset, and varianceEpsilon.\n */\nfunction batchNormalization_(\n x: Tensor|TensorLike, mean: Tensor|Tensor1D|TensorLike,\n variance: Tensor|Tensor1D|TensorLike, varianceEpsilon = .001,\n scale?: Tensor|Tensor1D|TensorLike,\n offset?: Tensor|Tensor1D|TensorLike): Tensor {\n warnDeprecation();\n return batchNorm_(x, mean, variance, offset, scale, varianceEpsilon);\n}\n\n/**\n * Batch normalization.\n *\n * As described in\n * [http://arxiv.org/abs/1502.03167](http://arxiv.org/abs/1502.03167).\n *\n * Mean, variance, scale, and offset can be of two shapes:\n * - The same shape as the input.\n * - In the common case, the depth dimension is the last dimension of x, so\n * the values would be an `tf.Tensor1D` of shape [depth].\n *\n * Also available are stricter rank-specific methods with the same signature\n * as this method that assert that parameters passed are of given rank\n * - `tf.batchNorm2d`\n * - `tf.batchNorm3d`\n * - `tf.batchNorm4d`\n *\n * @param x The input Tensor.\n * @param mean A mean Tensor.\n * @param variance A variance Tensor.\n * @param offset An offset Tensor.\n * @param scale A scale Tensor.\n * @param varianceEpsilon A small float number to avoid dividing by 0.\n */\n/** @doc {heading: 'Operations', subheading: 'Normalization'} */\nfunction batchNorm_(\n x: Tensor|TensorLike, mean: Tensor|Tensor1D|TensorLike,\n variance: Tensor|Tensor1D|TensorLike,\n offset?: Tensor|Tensor1D|TensorLike,\n scale?: Tensor|Tensor1D|TensorLike,\n varianceEpsilon?: number): Tensor {\n if (varianceEpsilon == null) {\n varianceEpsilon = 0.001;\n }\n const $x = convertToTensor(x, 'x', 'batchNorm');\n const $mean = convertToTensor(mean, 'mean', 'batchNorm');\n const $variance = convertToTensor(variance, 'variance', 'batchNorm');\n let $scale: Tensor|Tensor1D;\n if (scale != null) {\n $scale = convertToTensor(scale, 'scale', 'batchNorm');\n }\n let $offset: Tensor|Tensor1D;\n if (offset != null) {\n $offset = convertToTensor(offset, 'offset', 'batchNorm');\n }\n\n util.assert(\n $mean.rank === $variance.rank,\n () => 'Batch normalization gradient requires mean and variance to have ' +\n 'equal ranks.');\n util.assert(\n $offset == null || $mean.rank === $offset.rank,\n () => 'Batch normalization gradient requires mean and offset to have ' +\n 'equal ranks.');\n util.assert(\n $scale == null || $mean.rank === $scale.rank,\n () => 'Batch normalization gradient requires mean and scale to have ' +\n 'equal ranks.');\n\n const forward: ForwardFunc = (backend, save) => {\n const x4D: Tensor4D = xAs4D($x);\n\n const res = backend.batchNormalization(\n x4D, as1DOr4D($mean), as1DOr4D($variance), varianceEpsilon,\n as1DOr4D($scale), as1DOr4D($offset));\n\n save([$x, $mean, $variance, $scale]);\n\n return res;\n };\n\n const inputs: FusedBatchNormInputs =\n {x: $x, scale: $scale, offset: $offset, mean: $mean, variance: $variance};\n\n const attrs: FusedBatchNormAttrs = {varianceEpsilon};\n\n const res = ENGINE.runKernelFunc(\n forward, inputs as {} as NamedTensorMap, null /* gradient */,\n 'FusedBatchNorm', attrs as {} as NamedAttrMap);\n\n return res.reshape($x.shape);\n}\n\nfunction as1DOr4D(x: Tensor): Tensor4D|Tensor1D {\n if (x == null) {\n return null;\n }\n if (x.rank === 0) {\n return x.as1D();\n } else if (x.rank === 1) {\n return x as Tensor1D;\n } else if (x.rank === 2) {\n return x.as4D(1, 1, x.shape[0], x.shape[1]);\n } else if (x.rank === 3) {\n return x.as4D(1, x.shape[0], x.shape[1], x.shape[2]);\n }\n return x as Tensor4D;\n}\n\n// todo(yassogba): Remove batchNormalization since it is deprecated.\nexport const batchNormalization = op({batchNormalization_});\nexport const batchNorm = op({batchNorm_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {Tensor1D, Tensor2D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {batchNorm} from './batchnorm';\nimport {warnDeprecation} from './batchnorm_util';\nimport {op} from './operation';\n\n/**\n * Batch normalization, strictly for 2D. For the more relaxed version, see\n * `tf.batchNorm`.\n *\n * @param x The input Tensor.\n * @param mean A mean Tensor.\n * @param variance A variance Tensor.\n * @param offset An offset Tensor.\n * @param scale A scale Tensor.\n * @param varianceEpsilon A small float number to avoid dividing by 0.\n */\nfunction batchNorm2d_(\n x: Tensor2D|TensorLike, mean: Tensor2D|Tensor1D|TensorLike,\n variance: Tensor2D|Tensor1D|TensorLike,\n offset?: Tensor2D|Tensor1D|TensorLike, scale?: Tensor2D|Tensor1D|TensorLike,\n varianceEpsilon?: number): Tensor2D {\n const $x = convertToTensor(x, 'x', 'batchNorm');\n const $mean = convertToTensor(mean, 'mean', 'batchNorm');\n const $variance = convertToTensor(variance, 'variance', 'batchNorm');\n let $scale: Tensor2D|Tensor1D;\n if (scale != null) {\n $scale = convertToTensor(scale, 'scale', 'batchNorm');\n }\n let $offset: Tensor2D|Tensor1D;\n if (offset != null) {\n $offset = convertToTensor(offset, 'offset', 'batchNorm');\n }\n util.assert(\n $x.rank === 2,\n () => `Error in batchNorm3D: x must be rank 3 but got rank ` +\n `${$x.rank}.`);\n util.assert(\n $mean.rank === 2 || $mean.rank === 1,\n () => `Error in batchNorm2D: mean must be rank 2 or rank 1 but ` +\n `got rank ${$mean.rank}.`);\n util.assert(\n $variance.rank === 2 || $variance.rank === 1,\n () => `Error in batchNorm2D: variance must be rank 2 or rank 1 ` +\n `but got rank ${$variance.rank}.`);\n if ($scale != null) {\n util.assert(\n $scale.rank === 2 || $scale.rank === 1,\n () => `Error in batchNorm2D: scale must be rank 2 or rank 1 ` +\n `but got rank ${$scale.rank}.`);\n }\n if ($offset != null) {\n util.assert(\n $offset.rank === 2 || $offset.rank === 1,\n () => `Error in batchNorm2D: offset must be rank 2 or rank 1 ` +\n `but got rank ${$offset.rank}.`);\n }\n\n return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);\n}\n\n/**\n * @deprecated Please use `tf.batchNorm2d` instead and note the positional\n * argument change of scale, offset, and varianceEpsilon.\n */\nfunction batchNormalization2d_(\n x: Tensor2D|TensorLike, mean: Tensor2D|Tensor1D|TensorLike,\n variance: Tensor2D|Tensor1D|TensorLike, varianceEpsilon = .001,\n scale?: Tensor2D|Tensor1D|TensorLike,\n offset?: Tensor2D|Tensor1D|TensorLike): Tensor2D {\n warnDeprecation();\n return batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon);\n}\n\n// todo(yassogba): Remove batchNormalization2d since it is deprecated.\nexport const batchNormalization2d = op({batchNormalization2d_});\nexport const batchNorm2d = op({batchNorm2d_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {Tensor1D, Tensor3D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {batchNorm} from './batchnorm';\nimport {warnDeprecation} from './batchnorm_util';\nimport {op} from './operation';\n\n/**\n * Batch normalization, strictly for 3D. For the more relaxed version, see\n * `tf.batchNorm`.\n *\n * @param x The input Tensor.\n * @param mean A mean Tensor.\n * @param variance A variance Tensor.\n * @param offset An offset Tensor.\n * @param scale A scale Tensor.\n * @param varianceEpsilon A small float number to avoid dividing by 0.\n */\nfunction batchNorm3d_(\n x: Tensor3D|TensorLike, mean: Tensor3D|Tensor1D|TensorLike,\n variance: Tensor3D|Tensor1D|TensorLike,\n offset?: Tensor3D|Tensor1D|TensorLike, scale?: Tensor3D|Tensor1D|TensorLike,\n varianceEpsilon?: number): Tensor3D {\n const $x = convertToTensor(x, 'x', 'batchNorm');\n const $mean = convertToTensor(mean, 'mean', 'batchNorm');\n const $variance = convertToTensor(variance, 'variance', 'batchNorm');\n let $scale: Tensor3D|Tensor1D;\n if (scale != null) {\n $scale = convertToTensor(scale, 'scale', 'batchNorm');\n }\n let $offset: Tensor3D|Tensor1D;\n if (offset != null) {\n $offset = convertToTensor(offset, 'offset', 'batchNorm');\n }\n util.assert(\n $x.rank === 3,\n () => `Error in batchNorm3D: x must be rank 3 but got rank ` +\n `${$x.rank}.`);\n util.assert(\n $mean.rank === 3 || $mean.rank === 1,\n () => `Error in batchNorm3D: mean must be rank 3 or rank 1 but ` +\n `got rank ${$mean.rank}.`);\n util.assert(\n $variance.rank === 3 || $variance.rank === 1,\n () => `Error in batchNorm3D: variance must be rank 3 or rank 1 ` +\n `but got rank ${$variance.rank}.`);\n if ($scale != null) {\n util.assert(\n $scale.rank === 3 || $scale.rank === 1,\n () => `Error in batchNorm3D: scale must be rank 3 or rank 1 ` +\n `but got rank ${$scale.rank}.`);\n }\n if ($offset != null) {\n util.assert(\n $offset.rank === 3 || $offset.rank === 1,\n () => `Error in batchNorm3D: offset must be rank 3 or rank 1 ` +\n `but got rank ${$offset.rank}.`);\n }\n\n return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);\n}\n\n/**\n * @deprecated Please use `tf.batchNorm3d` instead and note the positional\n * argument change of scale, offset, and varianceEpsilon.\n */\nfunction batchNormalization3d_(\n x: Tensor3D|TensorLike, mean: Tensor3D|Tensor1D|TensorLike,\n variance: Tensor3D|Tensor1D|TensorLike, varianceEpsilon = .001,\n scale?: Tensor3D|Tensor1D|TensorLike,\n offset?: Tensor3D|Tensor1D|TensorLike): Tensor3D {\n warnDeprecation();\n return batchNorm3d_(x, mean, variance, offset, scale, varianceEpsilon);\n}\n\n// todo(yassogba): Remove batchNormalization3d since it is deprecated.\nexport const batchNormalization3d = op({batchNormalization3d_});\nexport const batchNorm3d = op({batchNorm3d_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {Tensor1D, Tensor4D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {batchNorm} from './batchnorm';\nimport {warnDeprecation} from './batchnorm_util';\nimport {op} from './operation';\n\n/**\n * Batch normalization, strictly for 4D. For the more relaxed version, see\n * `tf.batchNorm`.\n *\n * @param x The input Tensor.\n * @param mean A mean Tensor.\n * @param variance A variance Tensor.\n * @param offset An offset Tensor.\n * @param scale A scale Tensor.\n * @param varianceEpsilon A small float number to avoid dividing by 0.\n */\nfunction batchNorm4d_(\n x: Tensor4D|TensorLike, mean: Tensor4D|Tensor1D|TensorLike,\n variance: Tensor4D|Tensor1D|TensorLike,\n offset?: Tensor4D|Tensor1D|TensorLike, scale?: Tensor4D|Tensor1D|TensorLike,\n varianceEpsilon?: number): Tensor4D {\n const $x = convertToTensor(x, 'x', 'batchNorm');\n const $mean = convertToTensor(mean, 'mean', 'batchNorm');\n const $variance = convertToTensor(variance, 'variance', 'batchNorm');\n let $scale: Tensor4D|Tensor1D;\n if (scale != null) {\n $scale = convertToTensor(scale, 'scale', 'batchNorm');\n }\n let $offset: Tensor4D|Tensor1D;\n if (offset != null) {\n $offset = convertToTensor(offset, 'offset', 'batchNorm');\n }\n util.assert(\n $x.rank === 4,\n () => `Error in batchNorm4D: x must be rank 4 but got rank ` +\n `${$x.rank}.`);\n util.assert(\n $mean.rank === 4 || $mean.rank === 1,\n () => `Error in batchNorm4D: mean must be rank 4 or rank 1 but ` +\n `got rank ${$mean.rank}.`);\n util.assert(\n $variance.rank === 4 || $variance.rank === 1,\n () => `Error in batchNorm4D: variance must be rank 4 or rank 1 ` +\n `but got rank ${$variance.rank}.`);\n if ($scale != null) {\n util.assert(\n $scale.rank === 4 || $scale.rank === 1,\n () => `Error in batchNorm4D: scale must be rank 4 or rank 1 ` +\n `but got rank ${$scale.rank}.`);\n }\n if ($offset != null) {\n util.assert(\n $offset.rank === 4 || $offset.rank === 1,\n () => `Error in batchNorm4D: offset must be rank 4 or rank 1 ` +\n `but got rank ${$offset.rank}.`);\n }\n return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);\n}\n\n/**\n * @deprecated Please use `tf.batchNorm4d` instead and note the positional\n * argument change of scale, offset, and varianceEpsilon.\n */\nfunction batchNormalization4d_(\n x: Tensor4D|TensorLike, mean: Tensor4D|Tensor1D|TensorLike,\n variance: Tensor4D|Tensor1D|TensorLike, varianceEpsilon = .001,\n scale?: Tensor4D|Tensor1D|TensorLike,\n offset?: Tensor4D|Tensor1D|TensorLike): Tensor4D {\n warnDeprecation();\n return batchNorm4d_(x, mean, variance, offset, scale, varianceEpsilon);\n}\n\n// todo(yassogba): Remove batchNormalization4d since it is deprecated.\nexport const batchNormalization4d = op({batchNormalization4d_});\nexport const batchNorm4d = op({batchNorm4d_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {KernelBackend} from '../backends/backend';\nimport {ENGINE} from '../engine';\nimport {BroadcastTo, BroadCastToAttrs, BroadcastToInputs} from '../kernel_names';\nimport {NamedAttrMap} from '../kernel_registry';\nimport {Tensor} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {Rank, ShapeMap, TensorLike} from '../types';\n\nimport {op} from './operation';\n\n/**\n * Broadcast an array to a compatible shape NumPy-style.\n *\n * The tensor's shape is compared to the broadcast shape from end to beginning.\n * Ones are prepended to the tensor's shape until is has the same length as\n * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is\n * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then\n * the input tensor is tiled N times along that axis (using tf.tile).\n *\n * @param input The tensor that is to be broadcasted.\n * @param shape The input is to be broadcast to this shape.\n */\n/** @doc {heading: 'Tensors', subheading: 'Transformations'} */\nfunction broadcastTo_(\n x: Tensor|TensorLike, shape: ShapeMap[R]): Tensor {\n let input = convertToTensor(x, 'broadcastTo', 'x');\n const xShape = input.shape;\n\n if (shape.some(d => !(d > 0) || d % 1 !== 0)) {\n throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);\n }\n\n if (shape.length < input.rank) {\n throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${\n input.rank}.`);\n }\n\n if (shape.length > input.rank) {\n const newShape = input.shape.slice();\n while (newShape.length < shape.length) {\n newShape.unshift(1);\n }\n input = input.reshape(newShape);\n }\n\n const inputShape = input.shape;\n const reps: number[] = Array.from(shape);\n for (let i = shape.length - 1; i >= 0; i--) {\n if (inputShape[i] === shape[i]) {\n reps[i] = 1;\n } else if (input.shape[i] !== 1) {\n throw new Error(\n `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);\n }\n }\n const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);\n\n if (axes.length === 0) {\n return input.clone() as Tensor;\n }\n\n const forward = (backend: KernelBackend) => backend.tile(input, reps);\n const keepDims = true;\n const backward = (dy: Tensor) => ({x: () => dy.sum(axes, keepDims)});\n\n const inputs: BroadcastToInputs = {x: input};\n const attrs: BroadCastToAttrs = {shape, inputShape};\n\n return ENGINE.runKernelFunc(\n forward, inputs as unknown as NamedTensorMap, backward,\n BroadcastTo, attrs as unknown as NamedAttrMap) as Tensor;\n}\n\nexport const broadcastTo = op({broadcastTo_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Identity} from '../kernel_names';\n\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\n\nimport {op} from './operation';\n\n/**\n * Creates a new tensor with the same values and shape as the specified\n * tensor.\n *\n * ```js\n * const x = tf.tensor([1, 2]);\n *\n * x.clone().print();\n * ```\n *\n * @param x The tensor to clone.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction clone_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'clone', null);\n const forward = () =>\n ENGINE.makeTensorFromDataId($x.dataId, $x.shape, $x.dtype) as T;\n\n // Note this op is called tf.identity in python. Hence the kernel name used\n // here.\n return ENGINE.runKernelFunc(forward, {x: $x}, null /* grad */, Identity);\n}\n\nexport const clone = op({clone_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {whereImpl} from '../backends/where_impl';\nimport {ENGINE} from '../engine';\nimport {Tensor, Tensor2D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {assert, assertShapesMatch} from '../util';\nimport {assertAndGetBroadcastShape} from './broadcast_util';\nimport {op} from './operation';\nimport {zerosLike} from './tensor_ops';\n\n/**\n * Returns the truth value of `NOT x` element-wise.\n *\n * ```js\n * const a = tf.tensor1d([false, true], 'bool');\n *\n * a.logicalNot().print();\n * ```\n *\n * @param x The input tensor. Must be of dtype 'bool'.\n */\n/** @doc {heading: 'Operations', subheading: 'Logical'} */\nfunction logicalNot_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'logicalNot', 'bool');\n return ENGINE.runKernelFunc(backend => backend.logicalNot($x), {$x});\n}\n\n/**\n * Returns the truth value of `a AND b` element-wise. Supports broadcasting.\n *\n * ```js\n * const a = tf.tensor1d([false, false, true, true], 'bool');\n * const b = tf.tensor1d([false, true, false, true], 'bool');\n *\n * a.logicalAnd(b).print();\n * ```\n *\n * @param a The first input tensor. Must be of dtype bool.\n * @param b The second input tensor. Must be of dtype bool.\n */\n/** @doc {heading: 'Operations', subheading: 'Logical'} */\nfunction logicalAnd_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'logicalAnd', 'bool');\n const $b = convertToTensor(b, 'b', 'logicalAnd', 'bool');\n assertAndGetBroadcastShape($a.shape, $b.shape);\n\n return ENGINE.runKernelFunc(\n backend => backend.logicalAnd($a, $b), {a: $a, b: $b},\n null /* grad */, 'LogicalAnd') as T;\n}\n\n/**\n * Returns the truth value of `a OR b` element-wise. Supports broadcasting.\n *\n * ```js\n * const a = tf.tensor1d([false, false, true, true], 'bool');\n * const b = tf.tensor1d([false, true, false, true], 'bool');\n *\n * a.logicalOr(b).print();\n * ```\n * @param a The first input tensor. Must be of dtype bool.\n * @param b The second input tensor. Must be of dtype bool.\n */\n/** @doc {heading: 'Operations', subheading: 'Logical'} */\nfunction logicalOr_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'logicalOr', 'bool');\n const $b = convertToTensor(b, 'b', 'logicalOr', 'bool');\n assertAndGetBroadcastShape($a.shape, $b.shape);\n\n return ENGINE.runKernelFunc(backend => backend.logicalOr($a, $b), {$a, $b}) as\n T;\n}\n\n/**\n * Returns the truth value of `a XOR b` element-wise. Supports broadcasting.\n *\n * ```js\n * const a = tf.tensor1d([false, false, true, true], 'bool');\n * const b = tf.tensor1d([false, true, false, true], 'bool');\n *\n * a.logicalXor(b).print();\n * ```\n *\n * @param a The first input tensor. Must be of dtype bool.\n * @param b The second input tensor. Must be of dtype bool.\n */\n/** @doc {heading: 'Operations', subheading: 'Logical'} */\nfunction logicalXor_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'logicalXor', 'bool');\n const $b = convertToTensor(b, 'b', 'logicalXor', 'bool');\n assertAndGetBroadcastShape($a.shape, $b.shape);\n\n // x ^ y = (x | y) & ~(x & y)\n return logicalOr(a, b).logicalAnd(logicalAnd(a, b).logicalNot()) as T;\n}\n\n/**\n * Returns the elements, either `a` or `b` depending on the `condition`.\n *\n * If the condition is true, select from `a`, otherwise select from `b`.\n *\n * ```js\n * const cond = tf.tensor1d([false, false, true], 'bool');\n * const a = tf.tensor1d([1 , 2, 3]);\n * const b = tf.tensor1d([-1, -2, -3]);\n *\n * a.where(cond, b).print();\n * ```\n *\n * @param condition The input condition. Must be of dtype bool.\n * @param a If `condition` is rank 1, `a` may have a higher rank but\n * its first dimension must match the size of `condition`.\n * @param b A tensor with the same shape and type as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Logical'} */\nfunction where_(\n condition: Tensor|TensorLike, a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'where');\n const $b = convertToTensor(b, 'b', 'where');\n const $condition = convertToTensor(condition, 'condition', 'where', 'bool');\n\n assertShapesMatch($a.shape, $b.shape, 'Error in where: ');\n\n if ($condition.rank === 1) {\n // If condition rank is 1, then the first dimension must match the size of\n // condition.\n assert(\n $condition.shape[0] === $a.shape[0],\n () => 'The first dimension of `a` must match the size of `condition`.');\n } else {\n // A must have the same shape as condition.\n assertShapesMatch($condition.shape, $b.shape, 'Error in where: ');\n }\n\n // TODO(julianoks): Return null for condition gradient\n // when backprop supports it.\n const grad = (dy: T, saved: Tensor[]) => {\n const [$condition] = saved;\n return {\n $condition: () => zerosLike($condition).toFloat(),\n $a: () => dy.mul($condition.cast(dy.dtype)),\n $b: () => dy.mul($condition.logicalNot().cast(dy.dtype))\n } as {$a: () => T, $b: () => T, $condition: () => T};\n };\n\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.select($condition, $a, $b);\n save([$condition]);\n return res;\n }, {$condition, $a, $b}, grad) as T;\n}\n\n/**\n * Returns the coordinates of true elements of condition.\n *\n * The coordinates are returned in a 2-D tensor where the first dimension (rows)\n * represents the number of true elements, and the second dimension (columns)\n * represents the coordinates of the true elements. Keep in mind, the shape of\n * the output tensor can vary depending on how many true values there are in\n * input. Indices are output in row-major order. The resulting tensor has the\n * shape `[numTrueElems, condition.rank]`.\n *\n * This is analogous to calling the python `tf.where(cond)` without an x or y.\n *\n * ```js\n * const cond = tf.tensor1d([false, false, true], 'bool');\n * const result = await tf.whereAsync(cond);\n * result.print();\n * ```\n */\n/** @doc {heading: 'Operations', subheading: 'Logical'} */\nasync function whereAsync_(condition: Tensor|TensorLike): Promise {\n const $condition =\n convertToTensor(condition, 'condition', 'whereAsync', 'bool');\n const vals = await $condition.data();\n const res = whereImpl($condition.shape, vals);\n if (condition !== $condition) {\n $condition.dispose();\n }\n return res;\n}\n\nexport const logicalAnd = op({logicalAnd_});\nexport const logicalNot = op({logicalNot_});\nexport const logicalOr = op({logicalOr_});\nexport const logicalXor = op({logicalXor_});\nexport const where = op({where_});\nexport const whereAsync = whereAsync_;\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\nimport {makeTypesMatch} from '../tensor_util';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\n\nimport {div} from './div';\nimport {where} from './logical_ops';\nimport {op} from './operation';\nimport {zerosLike} from './tensor_ops';\n\n/**\n * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. Return 0\n * if denominator is 0.\n *\n * We also expose `tf.divStrict` which has the same signature as this op and\n * asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 4, 9, 16]);\n * const b = tf.tensor1d([1, 2, 3, 4]);\n * const c = tf.tensor1d([0, 0, 0, 0]);\n *\n * a.divNoNan(b).print(); // or tf.divNoNan(a, b)\n * a.divNoNan(c).print(); // or tf.divNoNan(a, c)\n * ```\n *\n * ```js\n * // Broadcast div a with b.\n * const a = tf.tensor1d([2, 4, 6, 8]);\n * const b = tf.scalar(2);\n * const c = tf.scalar(0);\n *\n * a.divNoNan(b).print(); // or tf.divNoNan(a, b)\n * a.divNoNan(c).print(); // or tf.divNoNan(a, c)\n * ```\n *\n * @param a The first tensor as the numerator.\n * @param b The second tensor as the denominator. Must have the same dtype as\n * `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */\nfunction divNoNan_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n // TODO: Make this into its own kernel.\n let $a = convertToTensor(a, 'a', 'div');\n let $b = convertToTensor(b, 'b', 'div');\n [$a, $b] = makeTypesMatch($a, $b);\n\n const divResult = div($a, $b);\n const zeros = zerosLike(divResult);\n const bEqualsZero = $b.equal(zeros);\n return where(bEqualsZero, zeros, divResult) as T;\n}\n\nexport const divNoNan = op({divNoNan_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE, ForwardFunc} from '../engine';\nimport {Tile, TileAttrs, TileInputs} from '../kernel_names';\nimport {NamedAttrMap} from '../kernel_registry';\nimport {Tensor} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {DataType, TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {op} from './operation';\n\n/**\n * Construct a tensor by repeating it the number of times given by reps.\n *\n * This operation creates a new tensor by replicating `input` `reps`\n * times. The output tensor's i'th dimension has `input.shape[i] *\n * reps[i]` elements, and the values of `input` are replicated\n * `reps[i]` times along the i'th dimension. For example, tiling\n * `[a, b, c, d]` by `[2]` produces `[a, b, c, d, a, b, c, d]`.\n *\n * ```js\n * const a = tf.tensor1d([1, 2]);\n *\n * a.tile([2]).print(); // or a.tile([2])\n * ```\n *\n * ```js\n * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n *\n * a.tile([1, 2]).print(); // or a.tile([1, 2])\n * ```\n * @param x The tensor to tile.\n * @param reps Determines the number of replications per dimension.\n */\n/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */\nfunction tile_(x: T|TensorLike, reps: number[]): T {\n const parseAs: DataType = null;\n const $x = convertToTensor(x, 'x', 'tile', parseAs);\n util.assert(\n $x.rank === reps.length,\n () => `Error in transpose: rank of input ${$x.rank} ` +\n `must match length of reps ${reps}.`);\n\n const forward: ForwardFunc = (backend, save) => {\n const res = backend.tile($x, reps);\n save([$x]);\n return res;\n };\n\n const inputsToSave = [$x];\n const inputs: TileInputs = {x: $x};\n const attrs: TileAttrs = {reps};\n\n return ENGINE.runKernelFunc(\n forward, inputs as unknown as NamedTensorMap, null /* grad */, Tile,\n attrs as unknown as NamedAttrMap, inputsToSave);\n}\n\nexport const tile = op({tile_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor2D} from '../tensor';\nimport {DataType} from '../types';\n\nimport {buffer, expandDims} from './array_ops';\nimport {op} from './operation';\nimport {tile} from './tile';\n\n/**\n * Create an identity matrix.\n *\n * @param numRows Number of rows.\n * @param numColumns Number of columns. Defaults to `numRows`.\n * @param batchShape If provided, will add the batch shape to the beginning\n * of the shape of the returned `tf.Tensor` by repeating the identity\n * matrix.\n * @param dtype Data type.\n * @returns Identity matrix of the specified size and data type, possibly\n * with batch repetition if `batchShape` is specified.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction eye_(\n numRows: number, numColumns?: number,\n batchShape?:\n [\n number\n ]|[number,\n number]|[number, number, number]|[number, number, number, number],\n dtype: DataType = 'float32'): Tensor2D {\n if (numColumns == null) {\n numColumns = numRows;\n }\n const buff = buffer([numRows, numColumns], dtype);\n const n = numRows <= numColumns ? numRows : numColumns;\n for (let i = 0; i < n; ++i) {\n buff.set(1, i, i);\n }\n const out = buff.toTensor().as2D(numRows, numColumns);\n if (batchShape == null) {\n return out;\n } else {\n if (batchShape.length === 1) {\n return tile(expandDims(out, 0), [batchShape[0], 1, 1]);\n } else if (batchShape.length === 2) {\n return tile(\n expandDims(expandDims(out, 0), 0),\n [batchShape[0], batchShape[1], 1, 1]);\n } else if (batchShape.length === 3) {\n return tile(\n expandDims(expandDims(expandDims(out, 0), 0), 0),\n [batchShape[0], batchShape[1], batchShape[2], 1, 1]);\n } else {\n throw new Error(\n `eye() currently supports only 1D and 2D ` +\n // tslint:disable-next-line:no-any\n `batchShapes, but received ${(batchShape as any).length}D.`);\n }\n }\n}\n\nexport const eye = op({eye_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor1D, Tensor2D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {op} from './operation';\n\n/**\n * Creates a `tf.Tensor` with values drawn from a multinomial distribution.\n *\n * ```js\n * const probs = tf.tensor([.75, .25]);\n * tf.multinomial(probs, 3).print();\n * ```\n *\n * @param logits 1D array with unnormalized log-probabilities, or\n * 2D array of shape `[batchSize, numOutcomes]`. See the `normalized`\n * parameter.\n * @param numSamples Number of samples to draw for each row slice.\n * @param seed The seed number.\n * @param normalized Whether the provided `logits` are normalized true\n * probabilities (sum to 1). Defaults to false.\n * @return 1D array of shape `[numSamples]`, or 2D array of shape\n * `[batchSize, numSamples]`, depending on the rank of the input.\n */\n/** @doc {heading: 'Tensors', subheading: 'Random'} */\nfunction multinomial_(\n logits: Tensor1D|Tensor2D|TensorLike, numSamples: number, seed?: number,\n normalized = false): Tensor1D|Tensor2D {\n const $logits = convertToTensor(logits, 'logits', 'multinomial');\n const numOutcomes = $logits.size;\n const origRank = $logits.rank;\n if (numOutcomes < 2) {\n throw new Error(\n `Error in multinomial: you need at least 2 outcomes, but got ` +\n `${numOutcomes}.`);\n }\n if (origRank > 2) {\n throw new Error(`Rank of probabilities must be 1 or 2, but is ${origRank}`);\n }\n seed = seed || Math.random();\n const logits2D = origRank === 1 ? $logits.as2D(1, -1) : $logits as Tensor2D;\n const res = ENGINE.runKernelFunc(\n backend => backend.multinomial(logits2D, normalized, numSamples, seed),\n {logits2D});\n return origRank === 1 ? res.as1D() : res;\n}\n\nexport const multinomial = op({multinomial_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE, ForwardFunc} from '../engine';\nimport {OneHot, OneHotAttrs, OneHotInputs} from '../kernel_names';\nimport {NamedAttrMap} from '../kernel_registry';\nimport {Tensor, Tensor1D} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\n\nimport {op} from './operation';\n\n/**\n * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take\n * value `onValue` (defaults to 1), while all other locations take value\n * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank\n * `R+1` with the last axis of size `depth`.\n *\n * ```js\n * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print();\n * ```\n *\n * @param indices `tf.Tensor` of indices with dtype `int32`.\n * @param depth The depth of the one hot dimension.\n * @param onValue A number used to fill in the output when the index matches\n * the location.\n * @param offValue A number used to fill in the output when the index does\n * not match the location.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction oneHot_(\n indices: Tensor|TensorLike, depth: number, onValue = 1,\n offValue = 0): Tensor {\n if (depth < 2) {\n throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`);\n }\n let $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32');\n const outShape = [...$indices.shape, depth];\n $indices = $indices.flatten();\n\n const forward: ForwardFunc = (backend, save) => {\n save([$indices]);\n return backend.oneHot($indices as Tensor1D, depth, onValue, offValue);\n };\n\n const inputs: OneHotInputs = {indices: $indices};\n const attrs: OneHotAttrs = {depth, onValue, offValue};\n\n const result = ENGINE.runKernelFunc(\n forward, inputs as unknown as NamedTensorMap, null /* grad */, OneHot,\n attrs as unknown as NamedAttrMap);\n return result.reshape(outShape);\n}\n\nexport const oneHot = op({oneHot_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE, ForwardFunc} from '../engine';\nimport {PadV2, PadV2Attrs, PadV2Inputs} from '../kernel_names';\nimport {NamedAttrMap} from '../kernel_registry';\nimport {Tensor} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\n\nimport {op} from './operation';\n\n/**\n * Pads a `tf.Tensor` with a given value and paddings.\n *\n * This operation currently only implements the `CONSTANT` mode.\n *\n * Also available are stricter rank-specific methods with the same signature\n * as this method that assert that `paddings` is of given length.\n * - `tf.pad1d`\n * - `tf.pad2d`\n * - `tf.pad3d`\n * - `tf.pad4d`\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3, 4]);\n * x.pad([[1, 2]]).print();\n * ```\n * @param x The tensor to pad.\n * @param paddings An array of length `R` (the rank of the tensor), where\n * each element is a length-2 tuple of ints `[padBefore, padAfter]`,\n * specifying how much to pad along each dimension of the tensor.\n * @param constantValue The pad value to use. Defaults to 0.\n */\n/** @doc {heading: 'Tensors', subheading: 'Transformations'} */\nfunction pad_(\n x: T|TensorLike, paddings: Array<[number, number]>, constantValue = 0): T {\n const $x = convertToTensor(x, 'x', 'pad');\n if ($x.rank === 0) {\n throw new Error('pad(scalar) is not defined. Pass non-scalar to pad');\n }\n const forward: ForwardFunc = (backend, save) => {\n save([$x]);\n return backend.pad($x, paddings, constantValue);\n };\n\n const attrs: PadV2Attrs = {paddings, constantValue};\n const inputs: PadV2Inputs = {x: $x};\n return ENGINE.runKernelFunc(\n forward, inputs as unknown as NamedTensorMap, null /* grad */, PadV2,\n attrs as unknown as NamedAttrMap);\n}\n\nexport const pad = op({pad_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {Tensor1D} from '../tensor';\nimport {TensorLike} from '../types';\nimport {assert} from '../util';\nimport {op} from './operation';\nimport {pad} from './pad';\n\n/**\n * Pads a `tf.Tensor1D` with a given value and paddings. See `pad` for details.\n */\nfunction pad1d_(\n x: Tensor1D|TensorLike, paddings: [number, number],\n constantValue = 0): Tensor1D {\n assert(\n paddings.length === 2,\n () => 'Invalid number of paddings. Must be length of 2.');\n return pad(x, [paddings], constantValue);\n}\n\nexport const pad1d = op({pad1d_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {Tensor2D} from '../tensor';\nimport {TensorLike} from '../types';\nimport {assert} from '../util';\nimport {op} from './operation';\nimport {pad} from './pad';\n\n/**\n * Pads a `tf.Tensor2D` with a given value and paddings. See `pad` for details.\n */\nfunction pad2d_(\n x: Tensor2D|TensorLike, paddings: [[number, number], [number, number]],\n constantValue = 0): Tensor2D {\n assert(\n paddings.length === 2 && paddings[0].length === 2 &&\n paddings[1].length === 2,\n () => 'Invalid number of paddings. Must be length of 2 each.');\n return pad(x, paddings, constantValue);\n}\n\nexport const pad2d = op({pad2d_});\n","import {Tensor3D} from '../tensor';\nimport {TensorLike} from '../types';\nimport {assert} from '../util';\nimport {op} from './operation';\nimport {pad} from './pad';\n\n/**\n * Pads a `tf.Tensor3D` with a given value and paddings. See `pad` for details.\n */\nfunction pad3d_(\n x: Tensor3D|TensorLike,\n paddings: [[number, number], [number, number], [number, number]],\n constantValue = 0): Tensor3D {\n assert(\n paddings.length === 3 && paddings[0].length === 2 &&\n paddings[1].length === 2 && paddings[2].length === 2,\n () => 'Invalid number of paddings. Must be length of 2 each.');\n return pad(x, paddings, constantValue);\n}\n\nexport const pad3d = op({pad3d_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {Tensor4D} from '../tensor';\nimport {TensorLike} from '../types';\nimport {assert} from '../util';\nimport {op} from './operation';\nimport {pad} from './pad';\n\n/**\n * Pads a `tf.Tensor4D` with a given value and paddings. See `pad` for details.\n */\nfunction pad4d_(\n x: Tensor4D|TensorLike,\n paddings:\n [\n [number, number], [number, number], [number, number], [number, number]\n ],\n constantValue = 0): Tensor4D {\n assert(\n paddings.length === 4 && paddings[0].length === 2 &&\n paddings[1].length === 2 && paddings[2].length === 2 &&\n paddings[3].length === 2,\n () => 'Invalid number of paddings. Must be length of 2 each.');\n return pad(x, paddings, constantValue);\n}\n\nexport const pad4d = op({pad4d_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor} from '../tensor';\nimport {DataType, Rank, ShapeMap} from '../types';\nimport {sizeFromShape} from '../util';\n\nimport {op} from './operation';\n\n/**\n * Creates a `tf.Tensor` with values sampled from a random number generator\n * function defined by the user.\n *\n * @param shape An array of integers defining the output tensor shape.\n * @param randFunction A random number generator function which is called\n * for each element in the output tensor.\n * @param dtype The data type of the output tensor. Defaults to 'float32'.\n */\nfunction rand_(\n shape: ShapeMap[R], randFunction: () => number,\n dtype?: DataType): Tensor {\n const size = sizeFromShape(shape);\n let values = null;\n if (dtype == null || dtype === 'float32') {\n values = new Float32Array(size);\n } else if (dtype === 'int32') {\n values = new Int32Array(size);\n } else if (dtype === 'bool') {\n values = new Uint8Array(size);\n } else {\n throw new Error(`Unknown data type ${dtype}`);\n }\n for (let i = 0; i < size; i++) {\n values[i] = randFunction();\n }\n return ENGINE.makeTensor(values, shape, dtype) as Tensor;\n}\n\nexport const rand = op({rand_});\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from './engine';\nimport {inferShape} from './tensor_util_env';\nimport {RecursiveArray, TensorLike, TypedArray} from './types';\nimport {arraysEqual, flatten, isString, isTypedArray} from './util';\n\nconst TEST_EPSILON_FLOAT32 = 1e-3;\nexport const TEST_EPSILON_FLOAT16 = 1e-1;\n\nexport function expectArraysClose(\n actual: TypedArray|number|RecursiveArray,\n expected: TypedArray|number|RecursiveArray, epsilon?: number) {\n if (epsilon == null) {\n epsilon = testEpsilon();\n }\n return expectArraysPredicate(\n actual, expected, (a, b) => areClose(a as number, b as number, epsilon));\n}\n\nexport function testEpsilon() {\n return ENGINE.backend.floatPrecision() === 32 ? TEST_EPSILON_FLOAT32 :\n TEST_EPSILON_FLOAT16;\n}\n\nfunction expectArraysPredicate(\n actual: TensorLike, expected: TensorLike,\n predicate: (a: {}, b: {}) => boolean) {\n let checkClassType = true;\n if (isTypedArray(actual) || isTypedArray(expected)) {\n checkClassType = false;\n }\n if (isTypedArray(actual) && isTypedArray(expected)) {\n checkClassType = true;\n }\n if (checkClassType) {\n const aType = actual.constructor.name;\n const bType = expected.constructor.name;\n\n if (aType !== bType) {\n throw new Error(\n `Arrays are of different type. Actual: ${aType}. ` +\n `Expected: ${bType}`);\n }\n }\n\n if (Array.isArray(actual) && Array.isArray(expected)) {\n const actualShape = inferShape(actual);\n const expectedShape = inferShape(expected);\n if (!arraysEqual(actualShape, expectedShape)) {\n throw new Error(\n `Arrays have different shapes. ` +\n `Actual: [${actualShape}]. Expected: [${expectedShape}]`);\n }\n }\n\n const actualFlat =\n isTypedArray(actual) ? actual : flatten(actual as RecursiveArray);\n const expectedFlat = isTypedArray(expected) ?\n expected :\n flatten(expected as RecursiveArray);\n\n if (actualFlat.length !== expectedFlat.length) {\n throw new Error(\n `Arrays have different lengths actual: ${actualFlat.length} vs ` +\n `expected: ${expectedFlat.length}.\\n` +\n `Actual: ${actualFlat}.\\n` +\n `Expected: ${expectedFlat}.`);\n }\n for (let i = 0; i < expectedFlat.length; ++i) {\n const a = actualFlat[i];\n const e = expectedFlat[i];\n\n if (!predicate(a, e)) {\n throw new Error(\n `Arrays differ: actual[${i}] = ${a}, expected[${i}] = ${e}.\\n` +\n `Actual: ${actualFlat}.\\n` +\n `Expected: ${expectedFlat}.`);\n }\n }\n}\n\nexport interface DoneFn {\n (): void;\n fail: (message?: Error|string) => void;\n}\n\nexport function expectPromiseToFail(fn: () => Promise<{}>, done: DoneFn): void {\n fn().then(() => done.fail(), () => done());\n}\n\nexport function expectArraysEqual(actual: TensorLike, expected: TensorLike) {\n const exp = typeof expected === 'string' || typeof expected === 'number' ||\n typeof expected === 'boolean' ?\n [expected] as number[] :\n expected as number[];\n if (isString(actual) || isString((actual as string[])[0]) ||\n isString(expected) || isString((expected as string[])[0])) {\n // tslint:disable-next-line: triple-equals\n return expectArraysPredicate(actual, exp, (a, b) => a == b);\n }\n return expectArraysPredicate(\n actual, expected, (a, b) => areClose(a as number, b as number, 0));\n}\n\nexport function expectNumbersClose(a: number, e: number, epsilon?: number) {\n if (epsilon == null) {\n epsilon = testEpsilon();\n }\n if (!areClose(a, e, epsilon)) {\n throw new Error(`Numbers differ: actual === ${a}, expected === ${e}`);\n }\n}\n\nfunction areClose(a: number, e: number, epsilon: number): boolean {\n if (!isFinite(a) && !isFinite(e)) {\n return true;\n }\n if (isNaN(a) || isNaN(e) || Math.abs(a - e) > epsilon) {\n return false;\n }\n return true;\n}\n\nexport function expectValuesInRange(\n actual: TypedArray|number[], low: number, high: number) {\n for (let i = 0; i < actual.length; i++) {\n if (actual[i] < low || actual[i] > high) {\n throw new Error(\n `Value out of range:${actual[i]} low: ${low}, high: ${high}`);\n }\n }\n}\n\nexport function expectArrayBuffersEqual(\n actual: ArrayBuffer, expected: ArrayBuffer) {\n // Safari & Jasmine don't like comparing ArrayBuffers directly. Wrapping in\n // a Float32Array solves this issue.\n expect(new Float32Array(actual)).toEqual(new Float32Array(expected));\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as seedrandom from 'seedrandom';\n\nimport {expectNumbersClose, testEpsilon} from '../test_util';\nimport {TypedArray} from '../types';\n\nexport interface RandomBase {\n nextValue(): number;\n}\n\nexport interface RandomGamma {\n nextValue(): number;\n}\n\nexport interface RandNormalDataTypes {\n float32: Float32Array;\n int32: Int32Array;\n}\n\nexport interface RandGammaDataTypes {\n float32: Float32Array;\n int32: Int32Array;\n}\n\n// https://en.wikipedia.org/wiki/Marsaglia_polar_method\nexport class MPRandGauss implements RandomBase {\n private mean: number;\n private stdDev: number;\n private nextVal: number;\n private dtype?: keyof RandNormalDataTypes;\n private truncated?: boolean;\n private upper?: number;\n private lower?: number;\n private random: seedrandom.prng;\n\n constructor(\n mean: number, stdDeviation: number, dtype?: keyof RandNormalDataTypes,\n truncated?: boolean, seed?: number) {\n this.mean = mean;\n this.stdDev = stdDeviation;\n this.dtype = dtype;\n this.nextVal = NaN;\n this.truncated = truncated;\n if (this.truncated) {\n this.upper = this.mean + this.stdDev * 2;\n this.lower = this.mean - this.stdDev * 2;\n }\n const seedValue = seed ? seed : Math.random();\n this.random = seedrandom.alea(seedValue.toString());\n }\n\n /** Returns next sample from a Gaussian distribution. */\n public nextValue(): number {\n if (!isNaN(this.nextVal)) {\n const value = this.nextVal;\n this.nextVal = NaN;\n return value;\n }\n\n let resultX: number, resultY: number;\n let isValid = false;\n while (!isValid) {\n let v1: number, v2: number, s: number;\n do {\n v1 = 2 * this.random() - 1;\n v2 = 2 * this.random() - 1;\n s = v1 * v1 + v2 * v2;\n } while (s >= 1 || s === 0);\n\n const mul = Math.sqrt(-2.0 * Math.log(s) / s);\n resultX = this.mean + this.stdDev * v1 * mul;\n resultY = this.mean + this.stdDev * v2 * mul;\n\n if (!this.truncated || this.isValidTruncated(resultX)) {\n isValid = true;\n }\n }\n\n if (!this.truncated || this.isValidTruncated(resultY)) {\n this.nextVal = this.convertValue(resultY);\n }\n return this.convertValue(resultX);\n }\n\n /** Handles proper rounding for non-floating-point numbers. */\n private convertValue(value: number): number {\n if (this.dtype == null || this.dtype === 'float32') {\n return value;\n }\n return Math.round(value);\n }\n\n /** Returns true if less than 2-standard-deviations from the mean. */\n private isValidTruncated(value: number): boolean {\n return value <= this.upper && value >= this.lower;\n }\n}\n\n// Marsaglia, George, and Wai Wan Tsang. 2000. \"A Simple Method for Generating\n// Gamma Variables.\"\nexport class RandGamma implements RandomGamma {\n private alpha: number;\n private beta: number;\n private d: number;\n private c: number;\n private dtype?: keyof RandGammaDataTypes;\n private randu: seedrandom.prng;\n private randn: MPRandGauss;\n\n constructor(\n alpha: number, beta: number, dtype: keyof RandGammaDataTypes,\n seed?: number) {\n this.alpha = alpha;\n this.beta = 1 / beta; // convert rate to scale parameter\n this.dtype = dtype;\n\n const seedValue = seed ? seed : Math.random();\n this.randu = seedrandom.alea(seedValue.toString());\n this.randn = new MPRandGauss(0, 1, dtype, false, this.randu());\n\n if (alpha < 1) {\n this.d = alpha + (2 / 3);\n } else {\n this.d = alpha - (1 / 3);\n }\n this.c = 1 / Math.sqrt(9 * this.d);\n }\n\n /** Returns next sample from a gamma distribution. */\n public nextValue(): number {\n let x2: number, v0: number, v1: number, x: number, u: number, v: number;\n while (true) {\n do {\n x = this.randn.nextValue();\n v = 1 + (this.c * x);\n } while (v <= 0);\n v *= v * v;\n x2 = x * x;\n v0 = 1 - (0.331 * x2 * x2);\n v1 = (0.5 * x2) + (this.d * (1 - v + Math.log(v)));\n u = this.randu();\n if (u < v0 || Math.log(u) < v1) {\n break;\n }\n }\n v = (1 / this.beta) * this.d * v;\n if (this.alpha < 1) {\n v *= Math.pow(this.randu(), 1 / this.alpha);\n }\n return this.convertValue(v);\n }\n /** Handles proper rounding for non-floating-point numbers. */\n private convertValue(value: number): number {\n if (this.dtype === 'float32') {\n return value;\n }\n return Math.round(value);\n }\n}\n\nexport class UniformRandom implements RandomBase {\n private min: number;\n private range: number;\n private random: seedrandom.prng;\n private dtype?: keyof RandNormalDataTypes;\n\n constructor(\n min = 0, max = 1, dtype?: keyof RandNormalDataTypes,\n seed?: string|number) {\n this.min = min;\n this.range = max - min;\n this.dtype = dtype;\n if (seed == null) {\n seed = Math.random();\n }\n if (typeof seed === 'number') {\n seed = seed.toString();\n }\n\n if (!this.canReturnFloat() && this.range <= 1) {\n throw new Error(\n `The difference between ${min} - ${max} <= 1 and dtype is not float`);\n }\n this.random = seedrandom.alea(seed);\n }\n\n /** Handles proper rounding for non floating point numbers. */\n private canReturnFloat = () =>\n (this.dtype == null || this.dtype === 'float32');\n\n private convertValue(value: number): number {\n if (this.canReturnFloat()) {\n return value;\n }\n return Math.round(value);\n }\n\n nextValue() {\n return this.convertValue(this.min + this.range * this.random());\n }\n}\n\nexport function jarqueBeraNormalityTest(values: TypedArray|number[]) {\n // https://en.wikipedia.org/wiki/Jarque%E2%80%93Bera_test\n const n = values.length;\n const s = skewness(values);\n const k = kurtosis(values);\n const jb = n / 6 * (Math.pow(s, 2) + 0.25 * Math.pow(k - 3, 2));\n // JB test requires 2-degress of freedom from Chi-Square @ 0.95:\n // http://www.itl.nist.gov/div898/handbook/eda/section3/eda3674.htm\n const CHI_SQUARE_2DEG = 5.991;\n if (jb > CHI_SQUARE_2DEG) {\n throw new Error(`Invalid p-value for JB: ${jb}`);\n }\n}\n\nexport function expectArrayInMeanStdRange(\n actual: TypedArray|number[], expectedMean: number, expectedStdDev: number,\n epsilon?: number) {\n if (epsilon == null) {\n epsilon = testEpsilon();\n }\n const actualMean = mean(actual);\n expectNumbersClose(actualMean, expectedMean, epsilon);\n expectNumbersClose(\n standardDeviation(actual, actualMean), expectedStdDev, epsilon);\n}\n\nfunction mean(values: TypedArray|number[]) {\n let sum = 0;\n for (let i = 0; i < values.length; i++) {\n sum += values[i];\n }\n return sum / values.length;\n}\n\nfunction standardDeviation(values: TypedArray|number[], mean: number) {\n let squareDiffSum = 0;\n for (let i = 0; i < values.length; i++) {\n const diff = values[i] - mean;\n squareDiffSum += diff * diff;\n }\n return Math.sqrt(squareDiffSum / values.length);\n}\n\nfunction kurtosis(values: TypedArray|number[]) {\n // https://en.wikipedia.org/wiki/Kurtosis\n const valuesMean = mean(values);\n const n = values.length;\n let sum2 = 0;\n let sum4 = 0;\n for (let i = 0; i < n; i++) {\n const v = values[i] - valuesMean;\n sum2 += Math.pow(v, 2);\n sum4 += Math.pow(v, 4);\n }\n return (1 / n) * sum4 / Math.pow((1 / n) * sum2, 2);\n}\n\nfunction skewness(values: TypedArray|number[]) {\n // https://en.wikipedia.org/wiki/Skewness\n const valuesMean = mean(values);\n const n = values.length;\n let sum2 = 0;\n let sum3 = 0;\n for (let i = 0; i < n; i++) {\n const v = values[i] - valuesMean;\n sum2 += Math.pow(v, 2);\n sum3 += Math.pow(v, 3);\n }\n return (1 / n) * sum3 / Math.pow((1 / (n - 1)) * sum2, 3 / 2);\n}\n\nexport interface RandomBase {\n nextValue(): number;\n}\n\nexport interface RandomGamma {\n nextValue(): number;\n}\n\nexport interface RandNormalDataTypes {\n float32: Float32Array;\n int32: Int32Array;\n}\n\nexport interface RandGammaDataTypes {\n float32: Float32Array;\n int32: Int32Array;\n}\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\nimport {Rank, ShapeMap} from '../types';\n\nimport {buffer} from './array_ops';\nimport {op} from './operation';\nimport {RandGamma} from './rand_util';\n\n/**\n * Creates a `tf.Tensor` with values sampled from a gamma distribution.\n *\n * ```js\n * tf.randomGamma([2, 2], 1).print();\n * ```\n *\n * @param shape An array of integers defining the output tensor shape.\n * @param alpha The shape parameter of the gamma distribution.\n * @param beta The inverse scale parameter of the gamma distribution. Defaults\n * to 1.\n * @param dtype The data type of the output. Defaults to float32.\n * @param seed The seed for the random number generator.\n */\n/** @doc {heading: 'Tensors', subheading: 'Random'} */\nfunction randomGamma_(\n shape: ShapeMap[R], alpha: number, beta = 1,\n dtype: 'float32'|'int32' = 'float32', seed?: number): Tensor {\n if (beta == null) {\n beta = 1;\n }\n if (dtype == null) {\n dtype = 'float32';\n }\n if (dtype !== 'float32' && dtype !== 'int32') {\n throw new Error(`Unsupported data type ${dtype}`);\n }\n const rgamma = new RandGamma(alpha, beta, dtype, seed);\n const res = buffer(shape, dtype);\n for (let i = 0; i < res.values.length; i++) {\n res.values[i] = rgamma.nextValue();\n }\n return res.toTensor();\n}\n\nexport const randomGamma = op({randomGamma_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\nimport {DataType, Rank, ShapeMap} from '../types';\n\nimport {buffer} from './array_ops';\nimport {op} from './operation';\nimport {MPRandGauss} from './rand_util';\n\n/**\n * Creates a `tf.Tensor` with values sampled from a normal distribution.\n *\n * ```js\n * tf.randomNormal([2, 2]).print();\n * ```\n *\n * @param shape An array of integers defining the output tensor shape.\n * @param mean The mean of the normal distribution.\n * @param stdDev The standard deviation of the normal distribution.\n * @param dtype The data type of the output.\n * @param seed The seed for the random number generator.\n */\n/** @doc {heading: 'Tensors', subheading: 'Random'} */\nfunction randomNormal_(\n shape: ShapeMap[R], mean = 0, stdDev = 1, dtype?: 'float32'|'int32',\n seed?: number): Tensor {\n if (dtype != null && (dtype as DataType) === 'bool') {\n throw new Error(`Unsupported data type ${dtype}`);\n }\n const randGauss =\n new MPRandGauss(mean, stdDev, dtype, false /* truncated */, seed);\n const res = buffer(shape, dtype);\n for (let i = 0; i < res.values.length; i++) {\n res.values[i] = randGauss.nextValue();\n }\n return res.toTensor();\n}\n\nexport const randomNormal = op({randomNormal_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\nimport {DataType, Rank, ShapeMap} from '../types';\n\nimport {buffer} from './array_ops';\nimport {op} from './operation';\nimport {UniformRandom} from './rand_util';\n\n/**\n * Creates a `tf.Tensor` with values sampled from a uniform distribution.\n *\n * The generated values follow a uniform distribution in the range [minval,\n * maxval). The lower bound minval is included in the range, while the upper\n * bound maxval is excluded.\n *\n * ```js\n * tf.randomUniform([2, 2]).print();\n * ```\n *\n * @param shape An array of integers defining the output tensor shape.\n * @param minval The lower bound on the range of random values to generate.\n * Defaults to 0.\n * @param maxval The upper bound on the range of random values to generate.\n * Defaults to 1.\n * @param dtype The data type of the output tensor. Defaults to 'float32'.\n */\n/** @doc {heading: 'Tensors', subheading: 'Random'} */\nfunction randomUniform_(\n shape: ShapeMap[R], minval = 0, maxval = 1, dtype: DataType = 'float32',\n seed?: number|string): Tensor {\n const res = buffer(shape, dtype);\n const random = new UniformRandom(minval, maxval, null, seed);\n for (let i = 0; i < res.values.length; i++) {\n res.values[i] = random.nextValue();\n }\n return res.toTensor();\n}\n\nexport const randomUniform = op({randomUniform_});\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {op} from './operation';\n\n/**\n * Computes square of `x` element-wise: `x ^ 2`\n *\n * ```js\n * const x = tf.tensor1d([1, 2, Math.sqrt(2), -1]);\n *\n * x.square().print(); // or tf.square(x)\n * ```\n * @param x The input Tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction square_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'square');\n const attrs = {};\n const inputsToSave = [$x];\n const outputsToSave: boolean[] = [];\n return ENGINE.runKernelFunc((backend, save) => {\n save([$x]);\n return backend.square($x);\n }, {x: $x}, null /* grad */, 'Square', attrs, inputsToSave, outputsToSave);\n}\n\nexport const square = op({square_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE, ForwardFunc} from '../engine';\nimport {SquaredDifference, SquaredDifferenceInputs} from '../kernel_names';\nimport {Tensor} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {makeTypesMatch} from '../tensor_util';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\n\nimport {assertAndGetBroadcastShape} from './broadcast_util';\nimport {op} from './operation';\nimport {scalar} from './tensor_ops';\n\n/**\n * Returns (a - b) * (a - b) element-wise.\n * Supports broadcasting.\n *\n * We also expose `tf.squaredDifferenceStrict` which has the same signature as\n * this op and asserts that `a` and `b` are the same shape (does not\n * broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 4, 3, 16]);\n * const b = tf.tensor1d([1, 2, 9, 4]);\n *\n * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)\n * ```\n *\n * ```js\n * // Broadcast squared difference a with b.\n * const a = tf.tensor1d([2, 4, 6, 8]);\n * const b = tf.scalar(5);\n *\n * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)\n * ```\n *\n * @param a The first tensor.\n * @param b The second tensor. Must have the same type as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */\nfunction squaredDifference_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'squaredDifference');\n let $b = convertToTensor(b, 'b', 'squaredDifference');\n [$a, $b] = makeTypesMatch($a, $b);\n\n assertAndGetBroadcastShape($a.shape, $b.shape);\n const der = (dy: Tensor, saved: Tensor[]) => {\n const [$a, $b] = saved;\n const two = scalar(2);\n const derA = () => dy.mul($a.sub($b).mul(two));\n const derB = () => dy.mul($b.sub($a).mul(two));\n return {a: derA, b: derB};\n };\n const forward: ForwardFunc = (backend, save) => {\n const res = backend.squaredDifference($a, $b);\n save([$a, $b]);\n return res;\n };\n\n const inputs: SquaredDifferenceInputs = {a: $a, b: $b};\n const attrs = {};\n\n const inputsToSave = [$a, $b];\n const outputToSave: boolean[] = [];\n return ENGINE.runKernelFunc(\n forward, inputs as unknown as NamedTensorMap, der,\n SquaredDifference, attrs, inputsToSave, outputToSave) as T;\n}\n\nexport const squaredDifference = op({squaredDifference_});\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\nimport {DataType, Rank, ShapeMap} from '../types';\n\nimport {buffer} from './array_ops';\nimport {op} from './operation';\nimport {MPRandGauss} from './rand_util';\n\n/**\n * Creates a `tf.Tensor` with values sampled from a truncated normal\n * distribution.\n *\n * ```js\n * tf.truncatedNormal([2, 2]).print();\n * ```\n *\n * The generated values follow a normal distribution with specified mean and\n * standard deviation, except that values whose magnitude is more than 2\n * standard deviations from the mean are dropped and re-picked.\n *\n * @param shape An array of integers defining the output tensor shape.\n * @param mean The mean of the normal distribution.\n * @param stdDev The standard deviation of the normal distribution.\n * @param dtype The data type of the output tensor.\n * @param seed The seed for the random number generator.\n */\n/** @doc {heading: 'Tensors', subheading: 'Creation'} */\nfunction truncatedNormal_(\n shape: ShapeMap[R], mean = 0, stdDev = 1, dtype?: 'float32'|'int32',\n seed?: number): Tensor {\n if (dtype != null && (dtype as DataType) === 'bool') {\n throw new Error(`Unsupported data type $ { dtype }`);\n }\n const randGauss =\n new MPRandGauss(mean, stdDev, dtype, true /* truncated */, seed);\n const res = buffer(shape, dtype);\n for (let i = 0; i < res.values.length; i++) {\n res.values[i] = randGauss.nextValue();\n }\n return res.toTensor();\n}\n\nexport const truncatedNormal = op({truncatedNormal_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor} from '../tensor';\nimport {makeTypesMatch} from '../tensor_util';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {assertShapesMatch} from '../util';\nimport {assertAndGetBroadcastShape} from './broadcast_util';\nimport {op} from './operation';\nimport {zerosLike} from './tensor_ops';\n\n/**\n * Returns the truth value of (a != b) element-wise. Supports broadcasting.\n *\n * We also expose `tf.notEqualStrict` which has the same signature as this op\n * and asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 2, 3]);\n * const b = tf.tensor1d([0, 2, 3]);\n *\n * a.notEqual(b).print();\n * ```\n * @param a The first input tensor.\n * @param b The second input tensor. Must have the same dtype as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Logical'} */\nfunction notEqual_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'notEqual');\n let $b = convertToTensor(b, 'b', 'notEqual');\n [$a, $b] = makeTypesMatch($a, $b);\n assertAndGetBroadcastShape($a.shape, $b.shape);\n return ENGINE.runKernelFunc(\n backend => backend.notEqual($a, $b), {a: $a, b: $b},\n null /* grad */, 'NotEqual') as T;\n}\n\n/**\n * Strict version of `tf.notEqual` that forces `a` and `b` to be of the same\n * shape.\n *\n * @param a The first input tensor.\n * @param b The second input tensor. Must have the same shape and dtype as\n * `a`.\n */\nfunction notEqualStrict_(\n a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'notEqualStrict');\n const $b = convertToTensor(b, 'b', 'notEqualStrict');\n assertShapesMatch($a.shape, $b.shape, 'Error in notEqualStrict: ');\n return $a.notEqual($b);\n}\n\n/**\n * Returns the truth value of (a < b) element-wise. Supports broadcasting.\n *\n * We also expose `tf.lessStrict` which has the same signature as this op and\n * asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 2, 3]);\n * const b = tf.tensor1d([2, 2, 2]);\n *\n * a.less(b).print();\n * ```\n * @param a The first input tensor.\n * @param b The second input tensor. Must have the same dtype as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Logical'} */\nfunction less_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'less');\n let $b = convertToTensor(b, 'b', 'less');\n [$a, $b] = makeTypesMatch($a, $b);\n assertAndGetBroadcastShape($a.shape, $b.shape);\n\n return ENGINE.runKernelFunc(\n backend => backend.less($a, $b), {a: $a, b: $b}, null /* grad */,\n 'Less') as T;\n}\n\n/**\n * Strict version of `tf.less` that forces `a` and `b` to be of the same\n * shape.\n *\n * @param a The first input tensor.\n * @param b The second input tensor. Must have the same shape and dtype as\n * `a`.\n */\nfunction lessStrict_(a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'lessStrict');\n const $b = convertToTensor(b, 'b', 'lessStrict');\n assertShapesMatch($a.shape, $b.shape, 'Error in lessStrict: ');\n return $a.less($b);\n}\n\n/**\n * Returns the truth value of (a == b) element-wise. Supports broadcasting.\n *\n * We also expose `tf.equalStrict` which has the same signature as this op\n * and asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 2, 3]);\n * const b = tf.tensor1d([2, 2, 2]);\n *\n * a.equal(b).print();\n * ```\n *\n * @param a The first input tensor.\n * @param b The second input tensor. Must have the same dtype as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Logical'} */\nfunction equal_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'equal');\n let $b = convertToTensor(b, 'b', 'equal');\n [$a, $b] = makeTypesMatch($a, $b);\n assertAndGetBroadcastShape($a.shape, $b.shape);\n\n return ENGINE.runKernelFunc(backend => backend.equal($a, $b), {$a, $b}) as T;\n}\n\nfunction equalStrict_(a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'equalStrict');\n const $b = convertToTensor(b, 'b', 'equalStrict');\n assertShapesMatch($a.shape, $b.shape, 'Error in equalStrict: ');\n return $a.equal($b);\n}\n\n/**\n * Returns the truth value of (a <= b) element-wise. Supports broadcasting.\n *\n * We also expose `tf.lessEqualStrict` which has the same signature as this op\n * and asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 2, 3]);\n * const b = tf.tensor1d([2, 2, 2]);\n *\n * a.lessEqual(b).print();\n * ```\n *\n * @param a The first input tensor.\n * @param b The second input tensor. Must have the same dtype as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Logical'} */\nfunction lessEqual_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'lessEqual');\n let $b = convertToTensor(b, 'b', 'lessEqual');\n [$a, $b] = makeTypesMatch($a, $b);\n assertAndGetBroadcastShape($a.shape, $b.shape);\n\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.lessEqual($a, $b);\n save([$a, $b]);\n return res;\n }, {a: $a, b: $b}, null /* grad */, 'LessEqual') as T;\n}\n\nfunction lessEqualStrict_(\n a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'lessEqualStrict');\n const $b = convertToTensor(b, 'b', 'lessEqualStrict');\n assertShapesMatch($a.shape, $b.shape, 'Error in lessEqualStrict: ');\n return $a.lessEqual($b);\n}\n\n/**\n * Returns the truth value of (a > b) element-wise. Supports broadcasting.\n *\n * We also expose `tf.greaterStrict` which has the same signature as this\n * op and asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 2, 3]);\n * const b = tf.tensor1d([2, 2, 2]);\n *\n * a.greater(b).print();\n * ```\n *\n * @param a The first input tensor.\n * @param b The second input tensor. Must have the same dtype as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Logical'} */\nfunction greater_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'greater');\n let $b = convertToTensor(b, 'b', 'greater');\n [$a, $b] = makeTypesMatch($a, $b);\n assertAndGetBroadcastShape($a.shape, $b.shape);\n\n return ENGINE.runKernelFunc(\n backend => backend.greater($a, $b), {a: $a, b: $b},\n null /* grad */, 'Greater') as T;\n}\n\nfunction greaterStrict_(a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'greaterStrict');\n const $b = convertToTensor(b, 'b', 'greaterStrict');\n assertShapesMatch($a.shape, $b.shape, 'Error in greaterStrict: ');\n return $a.greater($b);\n}\n\n/**\n * Returns the truth value of (a >= b) element-wise. Supports broadcasting.\n *\n * We also expose `tf.greaterEqualStrict` which has the same signature as this\n * op and asserts that `a` and `b` are the same shape (does not broadcast).\n *\n * ```js\n * const a = tf.tensor1d([1, 2, 3]);\n * const b = tf.tensor1d([2, 2, 2]);\n *\n * a.greaterEqual(b).print();\n * ```\n *\n * @param a The first input tensor.\n * @param b The second input tensor. Must have the same dtype as `a`.\n */\n/** @doc {heading: 'Operations', subheading: 'Logical'} */\nfunction greaterEqual_(\n a: Tensor|TensorLike, b: Tensor|TensorLike): T {\n let $a = convertToTensor(a, 'a', 'greaterEqual');\n let $b = convertToTensor(b, 'b', 'greaterEqual');\n [$a, $b] = makeTypesMatch($a, $b);\n assertAndGetBroadcastShape($a.shape, $b.shape);\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$a, $b] = saved;\n return {a: () => zerosLike($a), b: () => zerosLike($b)};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.greaterEqual($a, $b);\n save([$a, $b]);\n return res;\n }, {a: $a, b: $b}, grad, 'GreaterEqual') as T;\n}\n\nfunction greaterEqualStrict_(\n a: T|TensorLike, b: T|TensorLike): T {\n const $a = convertToTensor(a, 'a', 'greaterEqualStrict');\n const $b = convertToTensor(b, 'b', 'greaterEqualStrict');\n assertShapesMatch($a.shape, $b.shape, 'Error in greaterEqualStrict: ');\n return $a.greaterEqual($b);\n}\n\nexport const equal = op({equal_});\nexport const equalStrict = op({equalStrict_});\nexport const greater = op({greater_});\nexport const greaterEqual = op({greaterEqual_});\nexport const greaterEqualStrict = op({greaterEqualStrict_});\nexport const greaterStrict = op({greaterStrict_});\nexport const less = op({less_});\nexport const lessEqual = op({lessEqual_});\nexport const lessEqualStrict = op({lessEqualStrict_});\nexport const lessStrict = op({lessStrict_});\nexport const notEqual = op({notEqual_});\nexport const notEqualStrict = op({notEqualStrict_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor, Tensor1D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {assert, isInt, parseAxisParam} from '../util';\nimport {expandDims} from './array_ops';\nimport {getUndoAxesPermutation} from './axis_util';\nimport {maximum} from './binary_ops';\nimport {greaterEqual} from './compare';\nimport {logicalAnd, where} from './logical_ops';\nimport {op} from './operation';\nimport {collectGatherOpShapeInfo} from './segment_util';\nimport {ones, scalar, zerosLike} from './tensor_ops';\n\n/**\n * Computes the sum along segments of a `tf.Tensor`.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3, 4]);\n * const segmentIds = tf.tensor1d([1, 2, 0, 1], 'int32');\n * const numSegments = 3;\n *\n * x.unsortedSegmentSum(segmentIds, numSegments).print()\n * //or tf.unsortedSegmentSum(x, segmentIds, numSegments)\n * ```\n * @param x The `tf.Tensor` that will be summed along its segments.\n * @param segmentIds A `tf.Tensor1D` whose rank is equal to the rank of `x`'s\n * dimension along the `axis`. Maps each element of `x` to a segment.\n * @param numSegments The number of distinct `segmentIds`.\n */\n/** @doc {heading: 'Operations', subheading: 'Segment'} */\nfunction unsortedSegmentSum_(\n x: T|TensorLike, segmentIds: Tensor1D|TensorLike, numSegments: number): T {\n const $x = convertToTensor(x, 'x', 'unsortedSegmentSum');\n const $segmentIds =\n convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32');\n assert(isInt(numSegments), () => 'numSegments must be of dtype int');\n\n const gradFunc = (dy: T, saved: Tensor[]) => {\n const [$segmentIds] = saved;\n const derX = () => {\n return gatherDropNegatives(dy, $segmentIds as Tensor1D);\n };\n return {$x: derX};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.unsortedSegmentSum($x, $segmentIds, numSegments);\n save([$segmentIds]);\n return res;\n }, {$x}, gradFunc) as T;\n}\n\n/**\n * Gather slices from tensor `x`'s axis `axis` according to `indices`.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3, 4]);\n * const indices = tf.tensor1d([1, 3, 3], 'int32');\n *\n * x.gather(indices).print();\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n * const indices = tf.tensor1d([1, 1, 0], 'int32');\n *\n * x.gather(indices).print();\n * ```\n * @param x The input tensor whose slices to be gathered.\n * @param indices The indices of the values to extract.\n * @param axis The axis over which to select values. Defaults to 0.\n */\n/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */\nfunction gather_(\n x: T|TensorLike, indices: Tensor|TensorLike, axis = 0): T {\n const $x = convertToTensor(x, 'x', 'gather');\n const $indices = convertToTensor(indices, 'indices', 'gather', 'int32');\n axis = parseAxisParam(axis, $x.shape)[0];\n const shapeInfo = collectGatherOpShapeInfo($x, $indices, axis);\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$indices] = saved;\n const derX = () => {\n const paramsShape = $x.shape;\n const indicesSize = $indices.size;\n\n const outerShape = paramsShape.slice(0, axis);\n const outerDims = outerShape.length;\n const innerShape = paramsShape.slice(axis, paramsShape.length).slice(1);\n const innerDims = innerShape.length;\n\n const outerAxesIndices = arrayRange(0, outerDims);\n const innerAxesIndices =\n arrayRange(outerDims + 1, outerDims + 1 + innerDims);\n\n const valuesShape = arrayConcat([outerShape, [indicesSize], innerShape]);\n\n const values = dy.reshape(valuesShape);\n const reshapedIndices = $indices.reshape([indicesSize]);\n\n const transposeDims =\n arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]);\n const valuesTranspose = values.transpose(transposeDims);\n let paramsGrad = unsortedSegmentSum(\n valuesTranspose, reshapedIndices as Tensor1D, $x.shape[axis]);\n\n const invertTransposeDims = getUndoAxesPermutation(transposeDims);\n paramsGrad = paramsGrad.transpose(invertTransposeDims);\n\n return paramsGrad as T;\n };\n return {x: derX, indices: () => $indices};\n };\n return (ENGINE.runKernelFunc(\n (backend, save) => {\n const res = backend.gather($x, $indices.flatten(), axis);\n save([$indices]);\n return res;\n },\n {x: $x, indices: $indices}, grad, 'Gather', {axis}))\n .reshape(shapeInfo.outputShape) as T;\n}\n\nfunction arrayRange(start: number, stop: number): number[] {\n const result = [];\n for (let i = start; i < stop; ++i) {\n result.push(i);\n }\n return result;\n}\n\nfunction arrayConcat(arrays: number[][]): number[] {\n const result = [];\n for (let i = 0; i < arrays.length; ++i) {\n for (let j = 0; j < arrays[i].length; ++j) {\n result.push(arrays[i][j]);\n }\n }\n return result;\n}\n\nfunction gatherDropNegatives(x: T, indices: Tensor1D) {\n // Helper function for unsorted segment ops. Gathers params for\n // positive segment ids and gathers 0 for inputs with negative segment id.\n // Mirrors _GatherDropNegatives from tensorflow/python/ops/math_grad.py\n const zeroClippedIndices = maximum(indices, zerosLike(indices));\n const gathered = gather(x, zeroClippedIndices as Tensor1D);\n let isPositive = greaterEqual(indices, scalar(0, 'int32'));\n const numIters = gathered.rank - isPositive.rank;\n for (let i = 0; i < numIters; ++i) {\n isPositive = expandDims(isPositive, i + 1);\n }\n isPositive = logicalAnd(isPositive, ones(gathered.shape, 'bool'));\n const zeroSlice = zerosLike(gathered);\n return where(isPositive, gathered, zeroSlice);\n}\n\nexport const gather = op({gather_});\nexport const unsortedSegmentSum = op({unsortedSegmentSum_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {whereAsync} from './logical_ops';\nimport {gather} from './segment_ops';\n\n/**\n * Apply boolean mask to tensor.\n *\n * ```js\n * const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);\n * const mask = tf.tensor1d([1, 0, 1], 'bool');\n * const result = await tf.booleanMaskAsync(tensor, mask);\n * result.print();\n * ```\n *\n * @param tensor N-D tensor.\n * @param mask K-D boolean tensor, K <= N and K must be known statically.\n * @param axis A 0-D int Tensor representing the axis in tensor to mask from.\n * By default, axis is 0 which will mask from the first dimension.\n * Otherwise K + axis <= N.\n */\n/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */\nasync function booleanMaskAsync_(\n tensor: Tensor|TensorLike, mask: Tensor|TensorLike,\n axis?: number): Promise {\n const $tensor = convertToTensor(tensor, 'tensor', 'boolMask');\n const $mask = convertToTensor(mask, 'mask', 'boolMask', 'bool');\n\n const axisFrom = axis == null ? 0 : axis;\n const maskDim = $mask.rank;\n const tensorShape = $tensor.shape;\n\n util.assert(maskDim > 0, () => 'mask cannot be scalar');\n util.assertShapesMatch(\n tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape,\n `mask's shape must match the first K dimensions of tensor's shape,`);\n\n let leadingSize = 1;\n for (let i = axisFrom; i < axisFrom + maskDim; i++) {\n leadingSize *= tensorShape[i];\n }\n const targetTensorShape =\n tensorShape.slice(0, axisFrom)\n .concat([leadingSize], tensorShape.slice(axisFrom + maskDim));\n const reshapedTensor = $tensor.reshape(targetTensorShape);\n const reshapedMask = $mask.reshape([-1]);\n const positivePositions = await whereAsync(reshapedMask);\n const indices = positivePositions.squeeze([1]);\n\n const res = gather(reshapedTensor, indices, axisFrom);\n\n // Ensure no memory leak.\n if (tensor !== $tensor) {\n $tensor.dispose();\n }\n if (mask !== $mask) {\n $mask.dispose();\n }\n indices.dispose();\n reshapedTensor.dispose();\n reshapedMask.dispose();\n positivePositions.dispose();\n\n return res;\n}\n\nexport const booleanMaskAsync = booleanMaskAsync_;\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\nimport * as conv_util from './conv_util';\nimport {op} from './operation';\n\n/**\n * Computes a 1D convolution over the input x.\n *\n * @param x The input tensor, of rank 3 or rank 2, of shape\n * `[batch, width, inChannels]`. If rank 2, batch of 1 is assumed.\n * @param filter The filter, rank 3, of shape\n * `[filterWidth, inDepth, outDepth]`.\n * @param stride The number of entries by which the filter is moved right at\n * each step.\n * @param pad The type of padding algorithm.\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dataFormat An optional string from \"NWC\", \"NCW\". Defaults to \"NWC\",\n * the data is stored in the order of [batch, in_width, in_channels]. Only\n * \"NWC\" is currently supported.\n * @param dilation The dilation rate in which we sample input values in\n * atrous convolution. Defaults to `1`. If it is greater than 1, then\n * stride must be `1`.\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction conv1d_(\n x: T|TensorLike, filter: Tensor3D|TensorLike, stride: number,\n pad: 'valid'|'same'|number, dataFormat: 'NWC'|'NCW' = 'NWC', dilation = 1,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n const $x = convertToTensor(x, 'x', 'conv1d');\n const $filter = convertToTensor(filter, 'filter', 'conv1d');\n\n let x3D = $x as Tensor3D;\n let reshapedTo3D = false;\n if ($x.rank === 2) {\n reshapedTo3D = true;\n x3D = $x.as3D(1, $x.shape[0], $x.shape[1]);\n }\n\n util.assert(\n x3D.rank === 3,\n () => `Error in conv1d: input must be rank 3, but got rank ${x3D.rank}.`);\n util.assert(\n $filter.rank === 3,\n () => `Error in conv1d: filter must be rank 3, but got rank ` +\n `${$filter.rank}.`);\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in conv1d: pad must be an integer when using, ` +\n `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n\n util.assert(\n x3D.shape[2] === $filter.shape[1],\n () => `Error in conv1d: depth of input (${x3D.shape[2]}) must match ` +\n `input depth for filter ${$filter.shape[1]}.`);\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(stride, dilation),\n () => 'Error in conv1D: Either stride or dilation must be 1. ' +\n `Got stride ${stride} and dilation '${dilation}'`);\n util.assert(\n dataFormat === 'NWC',\n () => `Error in conv1d: got dataFormat of ${\n dataFormat} but only NWC is currently supported.`);\n\n const filter4D =\n $filter.as4D(1, $filter.shape[0], $filter.shape[1], $filter.shape[2]);\n const input4D = x3D.as4D(x3D.shape[0], 1, x3D.shape[1], x3D.shape[2]);\n const strides: [number, number] = [1, stride];\n const dilations: [number, number] = [1, dilation];\n\n const conv2dDataFormat = 'NHWC';\n\n const res = conv2d(\n input4D, filter4D, strides, pad, conv2dDataFormat, dilations,\n dimRoundingMode);\n\n if (reshapedTo3D) {\n return res.as2D(res.shape[2], res.shape[3]) as T;\n }\n return res.as3D(res.shape[0], res.shape[2], res.shape[3]) as T;\n}\n\n/**\n * Computes a 2D convolution over the input x.\n *\n * @param x The input tensor, of rank 4 or rank 3, of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is\n * assumed.\n * @param filter The filter, rank 4, of shape\n * `[filterHeight, filterWidth, inDepth, outDepth]`.\n * @param strides The strides of the convolution: `[strideHeight,\n * strideWidth]`.\n * @param pad The type of padding algorithm.\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dataFormat: An optional string from: \"NHWC\", \"NCHW\". Defaults to\n * \"NHWC\". Specify the data format of the input and output data. With the\n * default format \"NHWC\", the data is stored in the order of: [batch,\n * height, width, channels].\n * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`\n * in which we sample input values across the height and width dimensions\n * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single\n * number, then `dilationHeight == dilationWidth`. If it is greater than\n * 1, then all values of `strides` must be 1.\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction conv2d_(\n x: T|TensorLike, filter: Tensor4D|TensorLike,\n strides: [number, number]|number, pad: 'valid'|'same'|number,\n dataFormat: 'NHWC'|'NCHW' = 'NHWC',\n dilations: [number, number]|number = [1, 1],\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n const $x = convertToTensor(x, 'x', 'conv2d');\n const $filter = convertToTensor(filter, 'filter', 'conv2d');\n\n let x4D = $x as Tensor4D;\n let reshapedTo4D = false;\n\n if ($x.rank === 3) {\n reshapedTo4D = true;\n x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]);\n }\n util.assert(\n x4D.rank === 4,\n () => `Error in conv2d: input must be rank 4, but got rank ${x4D.rank}.`);\n util.assert(\n $filter.rank === 4,\n () => `Error in conv2d: filter must be rank 4, but got rank ` +\n `${$filter.rank}.`);\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in conv2d: pad must be an integer when using, ` +\n `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];\n util.assert(\n inDepth === $filter.shape[2],\n () => `Error in conv2d: depth of input (${inDepth}) must match ` +\n `input depth for filter ${$filter.shape[2]}.`);\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () => 'Error in conv2D: Either strides or dilations must be 1. ' +\n `Got strides ${strides} and dilations '${dilations}'`);\n\n const $dataFormat = conv_util.convertConv2DDataFormat(dataFormat);\n const convInfo = conv_util.computeConv2DInfo(\n x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, false,\n $dataFormat);\n\n const grad = (dy: Tensor4D, saved: Tensor[]) => {\n const [$filter, x4D] = saved as [Tensor4D, Tensor4D];\n util.assert(\n conv_util.tupleValuesAreOne(dilations),\n () => 'Error in gradient of conv2D: dilation rates greater than 1 ' +\n `are not yet supported in gradients. Got dilations '${dilations}'`);\n\n return {\n x: () => conv2dDerInput(x4D.shape, dy, $filter, strides, pad, dataFormat),\n filter: () =>\n conv2dDerFilter(x4D, dy, $filter.shape, strides, pad, dataFormat)\n };\n };\n\n const inputsToSave = [$filter, x4D];\n const res = ENGINE.runKernelFunc((backend, save) => {\n const res = backend.conv2d(x4D, $filter, convInfo);\n save([$filter, x4D]);\n\n return res;\n }, {x: x4D, filter: $filter}, grad, 'Conv2D', convInfo, inputsToSave);\n\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n return res as T;\n}\n\n/**\n * Computes the derivative of the input of a 2D convolution.\n *\n * @param xShape The shape of the input: [batch, height, width, inDepth].\n * If length of 3, batch of 1 is assumed.\n * @param dy The derivative of the output, of rank 4 or rank 3 of shape\n * `[batch, outHeight, outWidth, outDepth]`. If rank 3, batch of 1 is\n * assumed.\n * @param filter The filter, rank 4, of shape\n * `[filterHeight, filterWidth, inDepth, outDepth]`.\n * @param strides The strides of the convolution: `[strideHeight,\n * strideWidth]`.\n * @param pad The type of padding algorithm used:\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * @param dataFormat: An optional string from: \"NHWC\", \"NCHW\". Defaults to\n * \"NHWC\". Specify the data format of the input and output data. With the\n * default format \"NHWC\", the data is stored in the order of: [batch,\n * height, width, channels].\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n */\nfunction conv2dDerInput_(\n xShape: [number, number, number, number]|[number, number, number], dy: T,\n filter: Tensor4D, strides: [number, number]|number,\n pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW' = 'NHWC',\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n util.assert(\n xShape.length === dy.rank,\n () => `Length of inShape ` +\n `(${xShape.length}) and rank of dy (${dy.rank}) must match`);\n\n let xShape4D = xShape as [number, number, number, number];\n let dy4D = dy as Tensor4D;\n let reshapedTo4D = false;\n if (dy.rank === 3) {\n reshapedTo4D = true;\n dy4D = dy.as4D(1, dy.shape[0], dy.shape[1], dy.shape[2]);\n xShape4D = [1, xShape[0], xShape[1], xShape[2]];\n }\n\n util.assert(\n xShape4D.length === 4,\n () =>\n `Error in conv2dDerInput: inShape must be length 4, but got length ` +\n `${xShape4D.length}.`);\n util.assert(\n dy4D.rank === 4,\n () => `Error in conv2dDerInput: dy must be rank 4, but got ` +\n `rank ${dy4D.rank}`);\n util.assert(\n filter.rank === 4,\n () => `Error in conv2dDerInput: filter must be rank 4, but got ` +\n `rank ${filter.rank}`);\n const inDepth = dataFormat === 'NHWC' ? xShape4D[3] : xShape4D[1];\n const outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];\n util.assert(\n inDepth === filter.shape[2],\n () => `Error in conv2dDerInput: depth of input (${inDepth}) must ` +\n `match input depth for filter ${filter.shape[2]}.`);\n util.assert(\n outDepth === filter.shape[3],\n () => `Error in conv2dDerInput: depth of output (${outDepth}) must ` +\n `match output depth for filter ${filter.shape[3]}.`);\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in conv2dDerInput: pad must be an integer when using, ` +\n `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n\n const dilations = 1;\n\n const grad = (ddx: Tensor4D, saved: Tensor[]) => {\n const [filter, dy4D] = saved;\n return {\n dy4D: () => conv2d(\n ddx, filter as Tensor4D, strides, pad, dataFormat, dilations,\n dimRoundingMode),\n filter: () => conv2dDerFilter(\n ddx, dy4D as Tensor4D, (filter as Tensor4D).shape, strides, pad,\n dataFormat, dimRoundingMode)\n };\n };\n\n const $dataFormat = conv_util.convertConv2DDataFormat(dataFormat);\n const convInfo = conv_util.computeConv2DInfo(\n xShape4D, filter.shape, strides, dilations, pad, dimRoundingMode, false,\n $dataFormat);\n const res = ENGINE.runKernelFunc((backend, save) => {\n const res = backend.conv2dDerInput(dy4D, filter, convInfo);\n save([filter, dy4D]);\n return res;\n }, {dy4D, filter}, grad);\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n return res as T;\n}\n\n/**\n * Computes the derivative of the filter of a 2D convolution.\n *\n * @param x The input tensor, of rank 4 or rank 3 of shape\n * [batch, height, width, inChannels]. If rank 3, batch of 1 is assumed.\n * @param dy The dy image, of rank 4 or rank 3, of shape\n * [batch, height, width, outDepth]. If rank 3, batch of 1 is assumed.\n * @param filterShape The shape of the filter, length 4,\n * [filterHeight, filterWidth, inDepth, outDepth].\n * @param strides The strides of the convolution: [strideHeight,\n * strideWidth].\n * @param pad A string from: 'same', 'valid'. The type of padding algorithm\n * used in the forward prop of the op.\n * @param dataFormat: An optional string from: \"NHWC\", \"NCHW\". Defaults to\n * \"NHWC\". Specify the data format of the input and output data. With the\n * default format \"NHWC\", the data is stored in the order of: [batch,\n * height, width, channels].\n * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The\n * rounding mode used when computing output dimensions if pad is a\n * number. If none is provided, it will not round and error if the output\n * is of fractional size.\n */\nfunction conv2dDerFilter_(\n x: T, dy: T, filterShape: [number, number, number, number],\n strides: [number, number]|number, pad: 'valid'|'same'|number,\n dataFormat: 'NHWC'|'NCHW' = 'NHWC',\n dimRoundingMode?: 'floor'|'round'|'ceil'): Tensor4D {\n let x4D = x as Tensor4D;\n if (x.rank === 3) {\n x4D = x.as4D(1, x.shape[0], x.shape[1], x.shape[2]);\n }\n let dy4D = dy as Tensor4D;\n if (dy4D.rank === 3) {\n dy4D = dy.as4D(1, dy.shape[0], dy.shape[1], dy.shape[2]);\n }\n util.assert(\n x4D.rank === 4,\n () => `Error in conv2dDerFilter: input must be rank 4, but got shape ` +\n `${x4D.shape}.`);\n util.assert(\n dy4D.rank === 4,\n () => `Error in conv2dDerFilter: dy must be rank 4, but got shape ` +\n `${dy4D.shape}.`);\n util.assert(\n filterShape.length === 4,\n () => `Error in conv2dDerFilter: filterShape must be length 4, but got ` +\n `${filterShape}.`);\n const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];\n const outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];\n util.assert(\n inDepth === filterShape[2],\n () => `Error in conv2dDerFilter: depth of input ${inDepth}) must ` +\n `match input depth in filter (${filterShape[2]}.`);\n util.assert(\n outDepth === filterShape[3],\n () => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` +\n `match output depth for filter (${filterShape[3]}).`);\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in conv2dDerFilter: pad must be an integer when using, ` +\n `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n\n const dilations = 1;\n const $dataFormat = conv_util.convertConv2DDataFormat(dataFormat);\n const convInfo = conv_util.computeConv2DInfo(\n x4D.shape, filterShape, strides, dilations, pad, dimRoundingMode, false,\n $dataFormat);\n return ENGINE.runKernelFunc(\n backend => backend.conv2dDerFilter(x4D, dy4D, convInfo), {x4D, dy4D});\n}\n\n/**\n * Computes the transposed 2D convolution of an image, also known as a\n * deconvolution.\n *\n * @param x The input image, of rank 4 or rank 3, of shape\n * `[batch, height, width, inDepth]`. If rank 3, batch of 1 is assumed.\n * @param filter The filter, rank 4, of shape\n * `[filterHeight, filterWidth, outDepth, inDepth]`.\n * `inDepth` must match `inDepth` in `x`.\n * @param outputShape Output shape, of rank 4 or rank 3:\n * `[batch, height, width, outDepth]`. If rank 3, batch of 1 is assumed.\n * @param strides The strides of the original convolution:\n * `[strideHeight, strideWidth]`.\n * @param pad The type of padding algorithm used in the non-transpose version\n * of the op.\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction conv2dTranspose_(\n x: T|TensorLike, filter: Tensor4D|TensorLike,\n outputShape: [number, number, number, number]|[number, number, number],\n strides: [number, number]|number, pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n const $x = convertToTensor(x, 'x', 'conv2dTranspose');\n const $filter = convertToTensor(filter, 'filter', 'conv2dTranspose');\n\n return conv2dDerInput_(\n outputShape, $x, $filter, strides, pad, 'NHWC', dimRoundingMode);\n}\n\n/**\n * Depthwise 2D convolution.\n *\n * Given a 4D `input` array and a `filter` array of shape\n * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing\n * `inChannels` convolutional filters of depth 1, this op applies a\n * different filter to each input channel (expanding from 1 channel to\n * `channelMultiplier` channels for each), then concatenates the results\n * together. The output has `inChannels * channelMultiplier` channels.\n *\n * See\n * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](\n * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)\n * for more details.\n *\n * @param x The input tensor, of rank 4 or rank 3, of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is\n * assumed.\n * @param filter The filter tensor, rank 4, of shape\n * `[filterHeight, filterWidth, inChannels, channelMultiplier]`.\n * @param strides The strides of the convolution: `[strideHeight,\n * strideWidth]`. If strides is a single number, then `strideHeight ==\n * strideWidth`.\n * @param pad The type of padding algorithm.\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`\n * in which we sample input values across the height and width dimensions\n * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single\n * number, then `dilationHeight == dilationWidth`. If it is greater than\n * 1, then all values of `strides` must be 1.\n * @param dataFormat: An optional string from: \"NHWC\", \"NCHW\". Defaults to\n * \"NHWC\". Specify the data format of the input and output data. With the\n * default format \"NHWC\", the data is stored in the order of: [batch,\n * height, width, channels]. Only \"NHWC\" is currently supported.\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction depthwiseConv2d_(\n x: T|TensorLike, filter: Tensor4D|TensorLike,\n strides: [number, number]|number, pad: 'valid'|'same'|number,\n dataFormat: 'NHWC'|'NCHW' = 'NHWC',\n dilations: [number, number]|number = [1, 1],\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n const $x = convertToTensor(x, 'x', 'depthwiseConv2d');\n const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d');\n\n let x4D = $x as Tensor4D;\n let reshapedTo4D = false;\n if ($x.rank === 3) {\n reshapedTo4D = true;\n x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]);\n }\n util.assert(\n x4D.rank === 4,\n () => `Error in depthwiseConv2d: input must be rank 4, but got ` +\n `rank ${x4D.rank}.`);\n util.assert(\n $filter.rank === 4,\n () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ` +\n `${$filter.rank}.`);\n util.assert(\n x4D.shape[3] === $filter.shape[2],\n () => `Error in depthwiseConv2d: number of input channels ` +\n `(${x4D.shape[3]}) must match the inChannels dimension in ` +\n `filter ${$filter.shape[2]}.`);\n if (dilations == null) {\n dilations = [1, 1];\n }\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () =>\n 'Error in depthwiseConv2d: Either strides or dilations must be 1. ' +\n `Got strides ${strides} and dilations '${dilations}'`);\n\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in depthwiseConv2d: pad must be an integer when using, ` +\n `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n\n const convInfo = conv_util.computeConv2DInfo(\n x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode,\n true /* depthwise */);\n\n const grad = (dy: Tensor4D, saved: Tensor[]) => {\n util.assert(\n conv_util.tupleValuesAreOne(dilations),\n () => 'Error in gradient of depthwiseConv2d: dilation rates ' +\n `greater than 1 are not yet supported. Got dilations ` +\n `'${dilations}'`);\n const [x4D, $filter] = saved;\n return {\n x: () => depthwiseConv2dDerInput(\n (x4D as Tensor4D).shape, dy, $filter as Tensor4D, convInfo),\n filter: () => depthwiseConv2dDerFilter(\n x4D as Tensor4D, dy, ($filter as Tensor4D).shape, convInfo),\n };\n };\n\n const inputsToSave = [x4D, $filter];\n const res = ENGINE.runKernelFunc(\n (backend, save) => {\n const res = backend.depthwiseConv2D(x4D, $filter, convInfo);\n save([x4D, $filter]);\n return res;\n },\n {x: x4D, filter: $filter}, grad, 'DepthwiseConv2dNative', convInfo,\n inputsToSave);\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n return res as T;\n}\n\n/**\n * 2-D convolution with separable filters.\n *\n * Performs a depthwise convolution that acts separately on channels followed\n * by a pointwise convolution that mixes channels. Note that this is\n * separability between dimensions [1, 2] and 3, not spatial separability\n * between dimensions 1 and 2.\n *\n * See\n * [https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d](\n * https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d)\n * for more details.\n *\n * @param x The input tensor, of rank 4 or rank 3, of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is\n * assumed.\n * @param depthwiseFilter The depthwise filter tensor, rank 4, of shape\n * `[filterHeight, filterWidth, inChannels, channelMultiplier]`. This is\n * the filter used in the first step.\n * @param pointwiseFilter The pointwise filter tensor, rank 4, of shape\n * `[1, 1, inChannels * channelMultiplier, outChannels]`. This is\n * the filter used in the second step.\n * @param strides The strides of the convolution: `[strideHeight,\n * strideWidth]`. If strides is a single number, then `strideHeight ==\n * strideWidth`.\n * @param pad The type of padding algorithm.\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`\n * in which we sample input values across the height and width dimensions\n * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single\n * number, then `dilationHeight == dilationWidth`. If it is greater than\n * 1, then all values of `strides` must be 1.\n * @param dataFormat: An optional string from: \"NHWC\", \"NCHW\". Defaults to\n * \"NHWC\". Specify the data format of the input and output data. With the\n * default format \"NHWC\", the data is stored in the order of: [batch,\n * height, width, channels]. Only \"NHWC\" is currently supported.\n */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction separableConv2d_(\n x: T|TensorLike, depthwiseFilter: Tensor4D|TensorLike,\n pointwiseFilter: Tensor4D|TensorLike, strides: [number, number]|number,\n pad: 'valid'|'same', dilation: [number, number]|number = [1, 1],\n dataFormat: 'NHWC'|'NCHW' = 'NHWC'): T {\n const $x = convertToTensor(x, 'x', 'separableConv2d');\n const $depthwiseFilter =\n convertToTensor(depthwiseFilter, 'depthwiseFilter', 'separableConv2d');\n const $pointwiseFilter =\n convertToTensor(pointwiseFilter, 'pointwiseFilter', 'separableConv2d');\n\n let x4D = $x as Tensor4D;\n let reshapedTo4D = false;\n if ($x.rank === 3) {\n reshapedTo4D = true;\n x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]);\n }\n\n if (dataFormat === 'NCHW') {\n throw new Error(\n 'separableConv2d currently does not support dataFormat NCHW; only ' +\n 'NHWC is supported');\n }\n\n util.assert(\n x4D.rank === 4,\n () => `Error in separableConv2d: input must be rank 4, but got ` +\n `rank ${x4D.rank}.`);\n util.assert(\n $depthwiseFilter.rank === 4,\n () => `Error in separableConv2d: depthwise filter must be rank 4, but ` +\n `got rank ${$depthwiseFilter.rank}.`);\n util.assert(\n $pointwiseFilter.rank === 4,\n () => `Error in separableConv2d: pointwise filter must be rank 4, but ` +\n `got rank ${$depthwiseFilter.rank}.`);\n util.assert(\n $pointwiseFilter.shape[0] === 1,\n () =>\n `Error in separableConv2d: the first dimension of pointwise filter ` +\n ` must be 1, but got ${$pointwiseFilter.shape[0]}.`);\n util.assert(\n $pointwiseFilter.shape[1] === 1,\n () => `Error in separableConv2d: the second dimension of pointwise ` +\n `filter must be 1, but got ${$pointwiseFilter.shape[1]}.`);\n\n const inChannels = $depthwiseFilter.shape[2];\n const channelMultiplier = $depthwiseFilter.shape[3];\n util.assert(\n $pointwiseFilter.shape[2] === inChannels * channelMultiplier,\n () =>\n `Error in separableConv2d: the third dimension of pointwise filter ` +\n `must be ${inChannels * channelMultiplier}, ` +\n `but got ${$pointwiseFilter.shape[2]}.`);\n\n const depthwise = depthwiseConv2d(\n x4D, $depthwiseFilter, strides, pad, dataFormat, dilation);\n const pointwiseStride = 1;\n const res =\n conv2d(depthwise, $pointwiseFilter, pointwiseStride, 'valid', dataFormat);\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n return res as T;\n}\n\nfunction parseTupleParam(\n param: number|[number, number]|[number, number, number]):\n [number, number, number] {\n if (typeof param === 'number') {\n return [param, param, param];\n }\n if (param.length === 2) {\n return [param[0], param[1], 1];\n }\n return param;\n}\n\nfunction tupleValuesAreOne(\n param: number|[number, number]|[number, number, number]): boolean {\n const [dimA, dimB, dimC] = parseTupleParam(param);\n return dimA === 1 && dimB === 1 && dimC === 1;\n}\n\nfunction eitherStridesOrDilationsAreOne(\n strides: number|[number, number]|[number, number, number],\n dilations: number|[number, number]|[number, number, number]): boolean {\n return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);\n}\n\nfunction depthwiseConv2dDerInput_(\n xShape: [number, number, number, number]|[number, number, number], dy: T,\n filter: Tensor4D, convInfo: conv_util.Conv2DInfo): T {\n let dy4D = dy as Tensor4D;\n let reshapedTo4D = false;\n if (dy.rank === 3) {\n reshapedTo4D = true;\n dy4D = dy.as4D(1, dy.shape[0], dy.shape[1], dy.shape[2]);\n }\n const res = ENGINE.runKernelFunc(\n backend => backend.depthwiseConv2DDerInput(dy4D, filter, convInfo),\n {dy4D});\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n return res as T;\n}\n\nfunction depthwiseConv2dDerFilter_(\n x: T, dy: T, filterShape: [number, number, number, number],\n convInfo: conv_util.Conv2DInfo): Tensor4D {\n let x4D = x as Tensor4D;\n if (x.rank === 3) {\n x4D = x.as4D(1, x.shape[0], x.shape[1], x.shape[2]);\n }\n let dy4D = dy as Tensor4D;\n if (dy4D.rank === 3) {\n dy4D = dy.as4D(1, dy.shape[0], dy.shape[1], dy.shape[2]);\n }\n return ENGINE.runKernelFunc(\n backend => backend.depthwiseConv2DDerFilter(x4D, dy4D, convInfo),\n {x4D, dy4D});\n}\n\n/**\n * Computes a 3D convolution over the input x.\n *\n * @param x The input tensor, of rank 5 or rank 4, of shape\n * `[batch, depth, height, width, channels]`. If rank 4,\n * batch of 1 is assumed.\n * @param filter The filter, rank 5, of shape\n * `[filterDepth, filterHeight, filterWidth, inChannels, outChannels]`.\n * inChannels must match between input and filter.\n * @param strides The strides of the convolution: `[strideDepth, strideHeight,\n * strideWidth]`.\n * @param pad The type of padding algorithm.\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dataFormat: An optional string from: \"NDHWC\", \"NCDHW\". Defaults to\n * \"NDHWC\". Specify the data format of the input and output data. With the\n * default format \"NDHWC\", the data is stored in the order of: [batch,\n * depth, height, width, channels]. Only \"NDHWC\" is currently supported.\n * @param dilations The dilation rates: `[dilationDepth, dilationHeight,\n * dilationWidth]` in which we sample input values across the height\n * and width dimensions in atrous convolution. Defaults to `[1, 1, 1]`.\n * If `dilations` is a single number, then\n * `dilationDepth == dilationHeight == dilationWidth`. If it is greater\n * than 1, then all values of `strides` must be 1.\n */\n\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction conv3d_(\n x: T|TensorLike, filter: Tensor5D|TensorLike,\n strides: [number, number, number]|number, pad: 'valid'|'same',\n dataFormat: 'NDHWC'|'NCDHW' = 'NDHWC',\n dilations: [number, number, number]|number = [1, 1, 1]): T {\n const $x = convertToTensor(x, 'x', 'conv3d');\n const $filter = convertToTensor(filter, 'filter', 'conv3d');\n\n let x5D = $x as Tensor5D;\n let reshapedTo5D = false;\n\n if ($x.rank === 4) {\n reshapedTo5D = true;\n x5D = $x.as5D(1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]);\n }\n util.assert(\n x5D.rank === 5,\n () => `Error in conv3d: input must be rank 5, but got rank ${x5D.rank}.`);\n util.assert(\n $filter.rank === 5,\n () => `Error in conv3d: filter must be rank 5, but got rank ` +\n `${$filter.rank}.`);\n util.assert(\n x5D.shape[4] === $filter.shape[3],\n () => `Error in conv3d: depth of input (${x5D.shape[4]}) must match ` +\n `input depth for filter ${$filter.shape[3]}.`);\n util.assert(\n eitherStridesOrDilationsAreOne(strides, dilations),\n () => 'Error in conv3D: Either strides or dilations must be 1. ' +\n `Got strides ${strides} and dilations '${dilations}'`);\n util.assert(\n dataFormat === 'NDHWC',\n () => `Error in conv3d: got dataFormat of ${\n dataFormat} but only NDHWC is currently supported.`);\n\n const convInfo = conv_util.computeConv3DInfo(\n x5D.shape, $filter.shape, strides, dilations, pad);\n\n const grad = (dy: Tensor5D, saved: Tensor[]) => {\n util.assert(\n tupleValuesAreOne(dilations),\n () =>\n 'Error in gradient of conv3D: dilation rates greater than 1 are ' +\n `not yet supported in gradients. Got dilations '${dilations}'`);\n const [x5D, $filter] = saved;\n return {\n x: () => conv3dDerInput_(\n (x5D as Tensor5D).shape, dy, $filter as Tensor5D, strides, pad),\n $filter: () => conv3dDerFilter_(\n x5D as Tensor5D, dy, ($filter as Tensor5D).shape, strides, pad)\n };\n };\n\n const res = ENGINE.runKernelFunc((backend, save) => {\n const res = backend.conv3d(x5D, $filter, convInfo);\n save([x5D, $filter]);\n return res;\n }, {x: x5D, $filter}, grad);\n if (reshapedTo5D) {\n return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]) as\n T;\n }\n return res as T;\n}\n\n/**\n * Computes the derivative of the input of a 3D convolution.\n *\n * @param xShape The shape of the input: [batch, depth, height, width,\n * in_channels]. If length of 4, batch of 1 is assumed.\n * @param dy The derivative of the output, of rank 5 or rank 4 of shape\n * `[batch, outDepth, outHeight, outWidth, in_channels]`.\n * If rank 4, batch of 1 is assumed.\n * @param filter The filter, rank 5, of shape\n * `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`.\n * @param strides The strides of the convolution: `[strideDepth, strideHeight,\n * strideWidth]`.\n * @param pad The type of padding algorithm used:\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n */\nfunction conv3dDerInput_(\n xShape:\n [number, number, number, number,\n number]|[number, number, number, number],\n dy: T, filter: Tensor5D, strides: [number, number, number]|number,\n pad: 'valid'|'same'): T {\n util.assert(\n xShape.length === dy.rank,\n () => `Length of inShape ` +\n `(${xShape.length}) and rank of dy (${dy.rank}) must match`);\n\n let xShape5D = xShape as [number, number, number, number, number];\n let dy5D = dy as Tensor5D;\n let reshapedTo5D = false;\n if (dy.rank === 4) {\n reshapedTo5D = true;\n dy5D = dy.as5D(1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]);\n xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]];\n }\n\n const inDepth = xShape5D[4];\n const outDepth = dy5D.shape[4];\n util.assert(\n xShape5D.length === 5,\n () =>\n `Error in conv3dDerInput: inShape must be length 5, but got length ` +\n `${xShape5D.length}.`);\n util.assert(\n dy5D.rank === 5,\n () => `Error in conv3dDerInput: dy must be rank 5, but got ` +\n `rank ${dy5D.rank}`);\n util.assert(\n filter.rank === 5,\n () => `Error in conv3dDerInput: filter must be rank 5, but got ` +\n `rank ${filter.rank}`);\n util.assert(\n inDepth === filter.shape[3],\n () => `Error in conv3dDerInput: depth of input (${inDepth}) must ` +\n `match input depth for filter ${filter.shape[3]}.`);\n util.assert(\n outDepth === filter.shape[4],\n () => `Error in conv3dDerInput: depth of output (${outDepth}) must ` +\n `match output depth for filter ${filter.shape[4]}.`);\n\n const dilations = 1;\n\n const convInfo = conv_util.computeConv3DInfo(\n xShape5D, filter.shape, strides, dilations, pad);\n const res = ENGINE.runKernelFunc(\n backend => backend.conv3dDerInput(dy5D, filter, convInfo), {dy5D});\n if (reshapedTo5D) {\n return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]) as\n T;\n }\n return res as T;\n}\n\n/**\n * Computes the derivative of the filter of a 3D convolution.\n *\n * @param x The input tensor, of rank 5 or rank 4 of shape\n * [batch, depth, height, width, inChannels]. If rank 4, batch of 1 is\n * assumed.\n * @param dy The dy image, of rank 5 or rank 4, of shape\n * [batch, depth, height, width, outDepth]. If rank 4, batch of 1 is\n * assumed.\n * @param filterShape The shape of the filter, length 5,\n * [filterDepth, filterHeight, filterWidth, inDepth, outDepth].\n * @param strides The strides of the convolution: [strideDepth, strideHeight,\n * strideWidth].\n * @param pad A string from: 'same', 'valid'. The type of padding algorithm\n * used in the forward prop of the op.\n */\nfunction conv3dDerFilter_(\n x: T, dy: T, filterShape: [number, number, number, number, number],\n strides: [number, number, number]|number, pad: 'valid'|'same'): Tensor5D {\n let x5D = x as Tensor5D;\n if (x.rank === 4) {\n x5D = x.as5D(1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]);\n }\n let dy5D = dy as Tensor5D;\n if (dy5D.rank === 4) {\n dy5D = dy.as5D(1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]);\n }\n util.assert(\n x5D.rank === 5,\n () => `Error in conv3dDerFilter: input must be rank 5, but got shape ` +\n `${x5D.shape}.`);\n util.assert(\n dy5D.rank === 5,\n () => `Error in conv3dDerFilter: dy must be rank 5, but got shape ` +\n `${dy5D.shape}.`);\n util.assert(\n filterShape.length === 5,\n () => `Error in conv3dDerFilter: filterShape must be length 5, but got ` +\n `${filterShape}.`);\n util.assert(\n x5D.shape[4] === filterShape[3],\n () => `Error in conv3dDerFilter: depth of input ${x5D.shape[4]}) must ` +\n `match input depth in filter (${filterShape[3]}.`);\n util.assert(\n dy5D.shape[4] === filterShape[4],\n () => `Error in conv3dDerFilter: depth of dy (${dy5D.shape[4]}) must ` +\n `match output depth for filter (${filterShape[4]}).`);\n\n const dilations = 1;\n\n const convInfo = conv_util.computeConv3DInfo(\n x5D.shape, filterShape, strides, dilations, pad);\n return ENGINE.runKernelFunc(\n backend => backend.conv3dDerFilter(x5D, dy5D, convInfo), {x5D, dy5D});\n}\n\n/**\n * Computes the transposed 3D convolution of a volume, also known as a\n * deconvolution.\n *\n * @param x The input image, of rank 5 or rank 4, of shape\n * `[batch, depth, height, width, inDepth]`. If rank 4, batch of 1 is assumed.\n * @param filter The filter, rank 4, of shape\n * `[depth, filterHeight, filterWidth, outDepth, inDepth]`.\n * `inDepth` must match `inDepth` in `x`.\n * @param outputShape Output shape, of rank 5 or rank 4:\n * `[batch, depth, height, width, outDepth]`. If rank 3, batch of 1 is\n * assumed.\n * @param strides The strides of the original convolution:\n * `[strideDepth, strideHeight, strideWidth]`.\n * @param pad The type of padding algorithm used in the non-transpose version\n * of the op.\n */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction conv3dTranspose_(\n x: T|TensorLike, filter: Tensor5D|TensorLike,\n outputShape:\n [number, number, number, number,\n number]|[number, number, number, number],\n strides: [number, number, number]|number, pad: 'valid'|'same'): T {\n const $x = convertToTensor(x, 'x', 'conv3dTranspose');\n const $filter = convertToTensor(filter, 'filter', 'conv3dTranspose');\n\n return conv3dDerInput_(outputShape, $x, $filter, strides, pad);\n}\n\nexport const conv1d = op({conv1d_});\nexport const conv2d = op({conv2d_});\nexport const conv3d = op({conv3d_});\nexport const conv2dDerFilter = op({conv2dDerFilter_});\nexport const conv2dDerInput = op({conv2dDerInput_});\nexport const depthwiseConv2d = op({depthwiseConv2d_});\nexport const depthwiseConv2dDerInput = op({depthwiseConv2dDerInput_});\nexport const depthwiseConv2dDerFilter = op({depthwiseConv2dDerFilter_});\nexport const separableConv2d = op({separableConv2d_});\nexport const conv2dTranspose = op({conv2dTranspose_});\nexport const conv3dTranspose = op({conv3dTranspose_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor, Tensor1D, Tensor2D, Tensor3D} from '../tensor';\nimport {makeTypesMatch} from '../tensor_util';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\nimport {op} from './operation';\n\n/**\n * Computes the dot product of two matrices, A * B. These must be matrices.\n *\n * ```js\n * const a = tf.tensor2d([1, 2], [1, 2]);\n * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n *\n * a.matMul(b).print(); // or tf.matMul(a, b)\n * ```\n * @param a First matrix in dot product operation.\n * @param b Second matrix in dot product operation.\n * @param transposeA If true, `a` is transposed before multiplication.\n * @param transposeB If true, `b` is transposed before multiplication.\n */\n/** @doc {heading: 'Operations', subheading: 'Matrices'} */\nfunction matMul_(\n a: T|TensorLike, b: T|TensorLike, transposeA = false,\n transposeB = false): T {\n let $a = convertToTensor(a, 'a', 'matMul');\n let $b = convertToTensor(b, 'b', 'matMul');\n [$a, $b] = makeTypesMatch($a, $b);\n\n const innerShapeA =\n transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];\n const innerShapeB =\n transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];\n\n const outerShapeA =\n transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];\n const outerShapeB =\n transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];\n\n const outerDimsA = $a.shape.slice(0, -2);\n const outerDimsB = $b.shape.slice(0, -2);\n const batchDimA = util.sizeFromShape(outerDimsA);\n const batchDimB = util.sizeFromShape(outerDimsB);\n\n util.assert(\n $a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank,\n () => `Error in matMul: inputs must have the same rank of at least 2, ` +\n `got ranks ${$a.rank} and ${$b.rank}.`);\n\n util.assert(\n util.arraysEqual(outerDimsA, outerDimsB),\n () => `Error in matMul: outer dimensions (${outerDimsA}) and (` +\n `${outerDimsB}) of Tensors with shapes ${$a.shape} and ` +\n `${$b.shape} must match.`);\n\n util.assert(\n innerShapeA === innerShapeB,\n () => `Error in matMul: inner shapes (${innerShapeA}) and (` +\n `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` +\n `${$b.shape} and transposeA=${transposeA}` +\n ` and transposeB=${transposeB} must match.`);\n\n const outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]);\n\n const a3D = transposeA ? $a.as3D(batchDimA, innerShapeA, outerShapeA) :\n $a.as3D(batchDimA, outerShapeA, innerShapeA);\n const b3D = transposeB ? $b.as3D(batchDimB, outerShapeB, innerShapeB) :\n $b.as3D(batchDimB, innerShapeB, outerShapeB);\n\n const grad = (dy: Tensor3D, saved: Tensor[]) => {\n const [a3D, b3D] = saved as Tensor3D[];\n if (!transposeA && !transposeB) {\n return {\n a: () => dy.matMul(b3D, false, true),\n b: () => a3D.matMul(dy, true, false)\n };\n } else if (!transposeA && transposeB) {\n return {\n a: () => dy.matMul(b3D, false, false),\n b: () => dy.matMul(a3D, true, false)\n };\n } else if (transposeA && !transposeB) {\n return {\n a: () => b3D.matMul(dy, false, true),\n b: () => a3D.matMul(dy, false, false)\n };\n } else {\n return {\n a: () => b3D.matMul(dy, true, true),\n b: () => dy.matMul(a3D, true, true)\n };\n }\n };\n\n const attrs = {transposeA, transposeB};\n const res = ENGINE.runKernelFunc((backend, save) => {\n const res = backend.batchMatMul(a3D, b3D, transposeA, transposeB);\n save([a3D, b3D]);\n return res;\n }, {a: a3D, b: b3D}, grad, 'BatchMatMul', attrs);\n return res.reshape(outShape) as T;\n}\n\n/**\n * Computes the outer product of two vectors, `v1` and `v2`.\n *\n * ```js\n * const a = tf.tensor1d([1, 2, 3]);\n * const b = tf.tensor1d([3, 4, 5]);\n *\n * tf.outerProduct(a, b).print();\n * ```\n * @param v1 The first vector in the outer product operation.\n * @param v2 The second vector in the outer product operation.\n */\n/** @doc {heading: 'Operations', subheading: 'Matrices'} */\nfunction outerProduct_(\n v1: Tensor1D|TensorLike, v2: Tensor1D|TensorLike): Tensor2D {\n const $v1 = convertToTensor(v1, 'v1', 'outerProduct');\n const $v2 = convertToTensor(v2, 'v2', 'outerProduct');\n\n util.assert(\n $v1.rank === 1 && $v2.rank === 1,\n () => `Error in outerProduct: inputs must be rank 1, but got ranks ` +\n `${$v1.rank} and ${$v2.rank}.`);\n\n return $v1.as2D(-1, 1).matMul($v2.as2D(1, -1));\n}\n\n/**\n * Computes the dot product of two matrices and/or vectors, `t1` and `t2`.\n *\n * ```js\n * const a = tf.tensor1d([1, 2]);\n * const b = tf.tensor2d([[1, 2], [3, 4]]);\n * const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);\n *\n * a.dot(b).print(); // or tf.dot(a, b)\n * b.dot(a).print();\n * b.dot(c).print();\n * ```\n * @param t1 The first tensor in the dot operation.\n * @param t2 The second tensor in the dot operation.\n */\n/** @doc {heading: 'Operations', subheading: 'Matrices'} */\nfunction dot_(t1: Tensor|TensorLike, t2: Tensor|TensorLike): Tensor {\n const $t1 = convertToTensor(t1, 't1', 'dot');\n const $t2 = convertToTensor(t2, 't2', 'dot');\n util.assert(\n ($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2),\n () => `Error in dot: inputs must all be rank 1 or 2, but got ranks ` +\n `${$t1.rank} and ${$t2.rank}.`);\n\n const t1Inner = ($t1.rank === 1 ? $t1.size : $t1.shape[1]);\n const t2Inner = ($t2.rank === 1 ? $t2.size : $t2.shape[0]);\n\n util.assert(\n t1Inner === t2Inner,\n () => `Error in dot: inner dimensions of inputs must match, but got ` +\n `${t1Inner} and ${t2Inner}.`);\n\n if ($t1.rank === 1 && $t2.rank === 1) {\n return $t1.as2D(1, -1).matMul($t2.as2D(-1, 1)).asScalar();\n } else if ($t1.rank === 1 && $t2.rank === 2) {\n return $t1.as2D(1, -1).matMul($t2.as2D($t2.shape[0], $t2.shape[1])).as1D();\n } else if ($t1.rank === 2 && $t2.rank === 1) {\n return $t1.matMul($t2.as2D(-1, 1)).as1D();\n } else {\n return $t1.matMul($t2.as2D($t2.shape[0], $t2.shape[1]));\n }\n}\n\nexport const matMul = op({matMul_});\nexport const dot = op({dot_});\nexport const outerProduct = op({outerProduct_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\nimport {op} from './operation';\n\n/**\n * Reverses a `tf.Tensor1D`.\n *\n * @param x The input tensor.\n */\nfunction reverse1d_(x: Tensor1D|TensorLike): Tensor1D {\n const $x = convertToTensor(x, 'x', 'reverse');\n util.assert(\n $x.rank === 1,\n () => `Error in reverse1D: x must be rank 1 but got rank ${$x.rank}.`);\n return reverse($x, 0);\n}\n\n/**\n * Reverses a `tf.Tensor2D` along a specified axis.\n *\n * @param x The input tensor.\n * @param axis The set of dimensions to reverse. Must be in the\n * range [-rank(x), rank(x)). Defaults to all axes.\n */\nfunction reverse2d_(x: Tensor2D|TensorLike, axis?: number|number[]): Tensor2D {\n const $x = convertToTensor(x, 'x', 'reverse');\n util.assert(\n $x.rank === 2,\n () => `Error in reverse2D: x must be rank 2 but got rank ${$x.rank}.`);\n return reverse($x, axis);\n}\n\n/**\n * Reverses a `tf.Tensor3D` along a specified axis.\n *\n * @param x The input tensor.\n * @param axis The set of dimensions to reverse. Must be in the\n * range [-rank(x), rank(x)). Defaults to all axes.\n */\nfunction reverse3d_(x: Tensor3D|TensorLike, axis?: number|number[]): Tensor3D {\n const $x = convertToTensor(x, 'x', 'reverse');\n util.assert(\n $x.rank === 3,\n () => `Error in reverse3D: x must be rank 3 but got rank ${$x.rank}.`);\n return reverse($x, axis);\n}\n\n/**\n * Reverses a `tf.Tensor4D` along a specified axis.\n *\n * @param x The input tensor.\n * @param axis The set of dimensions to reverse. Must be in the\n * range [-rank(x), rank(x)). Defaults to all axes.\n */\nfunction reverse4d_(x: Tensor4D|TensorLike, axis?: number|number[]): Tensor4D {\n const $x = convertToTensor(x, 'x', 'reverse');\n util.assert(\n $x.rank === 4,\n () => `Error in reverse4D: x must be rank 4 but got rank ${$x.rank}.`);\n return reverse($x, axis);\n}\n\n/**\n * Reverses a `tf.Tensor` along a specified axis.\n *\n * Also available are stricter rank-specific methods that assert that `x` is\n * of the given rank:\n * - `tf.reverse1d`\n * - `tf.reverse2d`\n * - `tf.reverse3d`\n * - `tf.reverse4d`\n *\n * Except `tf.reverse1d` (which does not have axis param), all methods have\n * same signature as this method.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3, 4]);\n *\n * x.reverse().print();\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n *\n * const axis = 1;\n * x.reverse(axis).print();\n * ```\n * @param x The input tensor to be reversed.\n * @param axis The set of dimensions to reverse. Must be in the\n * range [-rank(x), rank(x)). Defaults to all axes.\n */\n/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */\nfunction reverse_(\n x: T|TensorLike, axis?: number|number[]): T {\n const $x = convertToTensor(x, 'x', 'reverse');\n\n if ($x.rank === 0) {\n return $x.clone();\n }\n const axes = util.parseAxisParam(axis, $x.shape);\n const grad = (dy: T) => {\n return {$x: () => dy.reverse(axes)};\n };\n const res =\n ENGINE.runKernelFunc(backend => backend.reverse($x, axes), {$x}, grad);\n return res.reshapeAs($x);\n}\n\nexport const reverse = op({reverse_});\nexport const reverse1d = op({reverse1d_});\nexport const reverse2d = op({reverse2d_});\nexport const reverse3d = op({reverse3d_});\nexport const reverse4d = op({reverse4d_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor, Tensor3D, Tensor4D, Tensor5D} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {batchToSpaceND, spaceToBatchND} from './array_ops';\nimport * as conv_util from './conv_util';\nimport {op} from './operation';\n\n/**\n * Computes the 2D max pooling of an image.\n *\n * @param x The input tensor, of rank 4 or rank 3 of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.\n * @param filterSize The filter size: `[filterHeight, filterWidth]`. If\n * `filterSize` is a single number, then `filterHeight == filterWidth`.\n * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If\n * `strides` is a single number, then `strideHeight == strideWidth`.\n * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`\n * in which we sample input values across the height and width dimensions\n * in dilated pooling. Defaults to `[1, 1]`. If `dilations` is a single\n * number, then `dilationHeight == dilationWidth`. If it is greater than\n * 1, then all values of `strides` must be 1.\n * @param pad The type of padding algorithm.\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n */\nfunction maxPoolImpl_(\n x: T|TensorLike, filterSize: [number, number]|number,\n strides: [number, number]|number, dilations: [number, number]|number,\n pad: 'valid'|'same'|number, dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n const $x = convertToTensor(x, 'x', 'maxPool');\n\n let x4D = $x as Tensor4D;\n let reshapedTo4D = false;\n if ($x.rank === 3) {\n reshapedTo4D = true;\n x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]);\n }\n if (dilations == null) {\n dilations = [1, 1];\n }\n util.assert(\n x4D.rank === 4,\n () => `Error in maxPool: input must be rank 4 but got rank ${x4D.rank}.`);\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () => 'Error in maxPool: Either strides or dilations must be 1. ' +\n `Got strides ${strides} and dilations '${dilations}'`);\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in maxPool: pad must be an integer when using, ` +\n `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n const convInfo = conv_util.computePool2DInfo(\n x4D.shape, filterSize, strides, dilations, pad, dimRoundingMode);\n if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&\n util.arraysEqual(convInfo.inShape, convInfo.outShape)) {\n return $x.clone();\n }\n\n const grad = (dy: Tensor4D, saved: Tensor[]) => {\n const [x4D, y] = saved;\n return {\n x: () => maxPoolBackprop(\n dy, x4D as Tensor4D, y as Tensor4D, filterSize, strides, dilations,\n pad)\n };\n };\n\n const inputsToSave = [x4D];\n const res = ENGINE.runKernelFunc((backend, save) => {\n const y = backend.maxPool(x4D, convInfo);\n save([x4D, y]);\n return y;\n }, {x: x4D}, grad, 'MaxPool', convInfo, inputsToSave);\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n return res as T;\n}\n\n/**\n * Computes the 2D max pooling of an image.\n *\n * @param x The input tensor, of rank 4 or rank 3 of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.\n * @param filterSize The filter size: `[filterHeight, filterWidth]`. If\n * `filterSize` is a single number, then `filterHeight == filterWidth`.\n * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If\n * `strides` is a single number, then `strideHeight == strideWidth`.\n * @param pad The type of padding algorithm.\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction maxPool_(\n x: T|TensorLike, filterSize: [number, number]|number,\n strides: [number, number]|number, pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n return maxPoolImpl_(x, filterSize, strides, 1, pad, dimRoundingMode);\n}\n\n/**\n * Computes the 2D average pooling of an image.\n *\n * @param x The input tensor, of rank 4 or rank 3 of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.\n * @param filterSize The filter size: `[filterHeight, filterWidth]`. If\n * `filterSize` is a single number, then `filterHeight == filterWidth`.\n * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If\n * `strides` is a single number, then `strideHeight == strideWidth`.\n * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`\n * in which we sample input values across the height and width dimensions\n * in dilated pooling. Defaults to `[1, 1]`. If `dilations` is a single\n * number, then `dilationHeight == dilationWidth`. If it is greater than\n * 1, then all values of `strides` must be 1.\n * @param pad The type of padding algorithm:\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n */\nfunction avgPoolImpl_(\n x: T|TensorLike, filterSize: [number, number]|number,\n strides: [number, number]|number, dilations: [number, number]|number,\n pad: 'valid'|'same'|number, dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n const $x = convertToTensor(x, 'x', 'avgPool', 'float32');\n if (dilations == null) {\n dilations = [1, 1];\n }\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () => 'Error in avgPool: Either strides or dilations must be 1. ' +\n `Got strides ${strides} and dilations '${dilations}'`);\n let x4D = $x as Tensor4D;\n let reshapedTo4D = false;\n if ($x.rank === 3) {\n reshapedTo4D = true;\n x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]);\n }\n util.assert(\n x4D.rank === 4,\n () => `Error in avgPool: x must be rank 4 but got rank ${x4D.rank}.`);\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in avgPool: pad must be an integer when using, ` +\n `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n\n const convInfo = conv_util.computePool2DInfo(\n x4D.shape, filterSize, strides, dilations, pad, dimRoundingMode);\n if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&\n util.arraysEqual(convInfo.inShape, convInfo.outShape)) {\n return $x.clone();\n }\n\n const grad = (dy: Tensor4D) => {\n return {\n x: () => avgPoolBackprop(dy, x4D, filterSize, strides, dilations, pad)\n };\n };\n\n let res = ENGINE.runKernelFunc(\n backend => backend.avgPool(x4D, convInfo), {x: x4D}, grad, 'AvgPool',\n convInfo);\n res = res.cast($x.dtype);\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n return res as T;\n}\n\n/**\n * Computes the 2D average pooling of an image.\n *\n * @param x The input tensor, of rank 4 or rank 3 of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.\n * @param filterSize The filter size: `[filterHeight, filterWidth]`. If\n * `filterSize` is a single number, then `filterHeight == filterWidth`.\n * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If\n * `strides` is a single number, then `strideHeight == strideWidth`.\n * @param pad The type of padding algorithm:\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction avgPool_(\n x: T|TensorLike, filterSize: [number, number]|number,\n strides: [number, number]|number, pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n return avgPoolImpl_(x, filterSize, strides, 1, pad, dimRoundingMode);\n}\n\n/**\n * Performs an N-D pooling operation\n *\n * @param input The input tensor, of rank 4 or rank 3 of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.\n * @param windowShape The filter size: `[filterHeight, filterWidth]`. If\n * `filterSize` is a single number, then `filterHeight == filterWidth`.\n * @param poolingType The type of pooling, either 'max' or 'avg'.\n * @param pad The type of padding algorithm:\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`\n * in which we sample input values across the height and width dimensions\n * in dilated pooling. Defaults to `[1, 1]`. If `dilationRate` is a single\n * number, then `dilationHeight == dilationWidth`. If it is greater than\n * 1, then all values of `strides` must be 1.\n * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If\n * `strides` is a single number, then `strideHeight == strideWidth`.\n */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction pool_(\n input: T|TensorLike, windowShape: [number, number]|number,\n poolingType: 'avg'|'max', pad: 'valid'|'same'|number,\n dilations?: [number, number]|number, strides?: [number, number]|number) {\n if (dilations == null) {\n dilations = [1, 1];\n }\n if (strides == null) {\n strides = 1;\n }\n if (pad === 0) {\n pad = 'valid';\n }\n const $x = convertToTensor(input, 'x', 'maxPool');\n let x4D = $x as Tensor4D;\n let reshapedTo4D = false;\n if ($x.rank === 3) {\n reshapedTo4D = true;\n x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]);\n }\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () => 'Error in pool: Either strides or dilations must be 1. ' +\n `Got strides ${strides} and dilations '${dilations}'`);\n const convInfo = conv_util.computePool2DInfo(\n x4D.shape, windowShape, strides, dilations, pad);\n const dilation: [number, number] =\n [convInfo.dilationHeight, convInfo.dilationWidth];\n\n // The following implementation does batchToSpace(pool(spaceToBatch(x)))\n // whenever dilation > 1 since the TF kernels do not support dilation > 1.\n // tslint:disable-next-line:max-line-length\n // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L1037\n\n let basePadding: number[][];\n if (pad === 'same') {\n basePadding = withSpaceToBatchBasePaddings(\n [convInfo.filterHeight, convInfo.filterWidth], dilation);\n } else {\n basePadding = [[0, 0], [0, 0]];\n }\n const isDilationOne = dilation[0] === 1 && dilation[1] === 1;\n const [adjustedPadding, adjustedCrops] = requiredSpaceToBatchPaddings(\n [convInfo.inHeight, convInfo.inWidth], dilation, basePadding);\n const convertedPad = isDilationOne ? pad : 'valid';\n const convertedX =\n isDilationOne ? x4D : spaceToBatchND(x4D, dilation, adjustedPadding);\n const forwardOp = poolingType === 'avg' ?\n () => avgPoolImpl_(\n convertedX, windowShape, strides, 1 /* dilation */, convertedPad) :\n () => maxPoolImpl_(\n convertedX, windowShape, strides, 1 /* dilation */, convertedPad);\n const y = forwardOp();\n const res = isDilationOne ? y : batchToSpaceND(y, dilation, adjustedCrops);\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n return res as T;\n}\n\n/**\n * Computes the backprop of a 2D max pool.\n *\n * @param dy The dy error, of rank 4 or rank 3 of shape\n * [batchSize, height, width, channels]. If rank 3, batch of 1 is\n * assumed.\n * @param input The original input image, of rank 4, of shape\n * [batchSize, height, width, channels].\n * @param output The original output image, of rank 4, of shape\n * [batchSize, outHeight, outWidth, channels].\n * @param filterSize The filter size: `[filterHeight, filterWidth]`. If\n * `filterSize` is a single number, then `filterHeight == filterWidth`.\n * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If\n * `strides` is a single number, then `strideHeight == strideWidth`.\n * @param pad A string from: 'same', 'valid'. The type of padding algorithm\n * used in the forward prop of the op.\n * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The\n * rounding mode used when computing output dimensions if pad is a\n * number. If none is provided, it will not round and error if the output\n * is of fractional size.\n */\nfunction maxPoolBackprop(\n dy: Tensor4D|TensorLike, input: Tensor4D|TensorLike,\n output: Tensor4D|TensorLike, filterSize: [number, number]|number,\n strides: [number, number]|number, dilations: [number, number]|number,\n pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): Tensor4D {\n const $dy = convertToTensor(dy, 'dy', 'maxPoolBackprop');\n const $input = convertToTensor(input, 'input', 'maxPoolBackprop');\n const $output = convertToTensor(output, 'output', 'maxPoolBackprop');\n util.assert(\n $input.rank === $dy.rank,\n () => `Rank of input (${$input.rank}) does not match rank of dy ` +\n `(${$dy.rank})`);\n if (dilations == null) {\n dilations = [1, 1];\n }\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () =>\n 'Error in maxPoolBackProp: Either strides or dilations must be 1. ' +\n `Got strides ${strides} and dilations '${dilations}'`);\n\n util.assert(\n $dy.rank === 4,\n () => `Error in maxPoolBackprop: dy must be rank 4 but got rank ` +\n `${$dy.rank}.`);\n util.assert(\n $input.rank === 4,\n () => `Error in maxPoolBackprop: input must be rank 4 but got rank ` +\n `${$input.rank}.`);\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in maxPoolBackprop: pad must be an integer when using, ` +\n `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n\n const convInfo = conv_util.computePool2DInfo(\n $input.shape, filterSize, strides, dilations, pad, dimRoundingMode);\n const res = ENGINE.runKernelFunc(\n backend => backend.maxPoolBackprop($dy, $input, $output, convInfo),\n {$dy, $input});\n return res;\n}\n\n/**\n * Computes the backprop of an 2D avg pool.\n *\n * @param dy The dy error, of rank 4 or rank 3 of shape\n * [batchSize, height, width, channels]. If rank 3, batch of 1 is\n * assumed.\n * @param input The input image, of rank 4 or rank 3 of shape\n * [batchSize, height, width, channels]. If rank 3, batch of 1 is\n * assumed.\n * @param filterSize The filter size: `[filterHeight, filterWidth]`. If\n * `filterSize` is a single number, then `filterHeight == filterWidth`.\n * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If\n * `strides` is a single number, then `strideHeight == strideWidth`.\n * @param pad A string from: 'same', 'valid'. The type of padding algorithm\n * used in the forward prop of the op.\n */\nfunction avgPoolBackprop(\n dy: T|TensorLike, input: T|TensorLike, filterSize: [number, number]|number,\n strides: [number, number]|number, dilations: [number, number]|number,\n pad: 'valid'|'same'|number): T {\n const $dy = convertToTensor(dy, 'dy', 'avgPoolBackprop');\n const $input = convertToTensor(input, 'input', 'avgPoolBackprop');\n util.assert(\n $input.rank === $dy.rank,\n () => `Rank of input (${$input.rank}) does not match rank of dy (${\n $dy.rank})`);\n if (dilations == null) {\n dilations = [1, 1];\n }\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () =>\n 'Error in avgPoolBackprop: Either strides or dilations must be 1. ' +\n `Got strides ${strides} and dilations '${dilations}'`);\n\n let input4D = $input as Tensor4D;\n let dy4D = $dy as Tensor4D;\n let reshapedTo4D = false;\n if ($input.rank === 3) {\n reshapedTo4D = true;\n input4D = $input.as4D(1, $input.shape[0], $input.shape[1], $input.shape[2]);\n dy4D = $dy.as4D(1, $dy.shape[0], $dy.shape[1], $dy.shape[2]);\n }\n\n util.assert(\n dy4D.rank === 4,\n () => `Error in avgPoolBackprop: dy must be rank 4 but got rank ` +\n `${dy4D.rank}.`);\n util.assert(\n input4D.rank === 4,\n () => `Error in avgPoolBackprop: input must be rank 4 but got rank ` +\n `${input4D.rank}.`);\n\n const convInfo = conv_util.computePool2DInfo(\n input4D.shape, filterSize, strides, dilations, pad);\n const res = ENGINE.runKernelFunc(\n backend => backend.avgPoolBackprop(dy4D, input4D, convInfo),\n {dy4D, input4D});\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n return res as T;\n}\n\n// Helper function to compute crops and paddings for pool with dilation > 1.\n// tslint:disable-next-line:max-line-length\n// https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/array_ops.py#L2184\nfunction requiredSpaceToBatchPaddings(\n inputShape: [number, number], blockShape: [number, number],\n basePadding: number[][]) {\n const padStart = basePadding.map(b => b[0]);\n const origPadEnd = basePadding.map(b => b[1]);\n const fullInputShape = inputShape.concat(padStart, origPadEnd);\n const padEndExtra = blockShape.map((b, i) => (b - fullInputShape[i] % b) % b);\n const padEnd = origPadEnd.map((s, i) => s + padEndExtra[i]);\n const paddings = blockShape.map((_, i) => [padStart[i], padEnd[i]]);\n const crops = blockShape.map((_, i) => [0, padEndExtra[i]]);\n return [paddings, crops];\n}\n\n// Helper function to compute base paddings for pool with dilation > 1.\n// tslint:disable-next-line:max-line-length\n// https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L524\nfunction withSpaceToBatchBasePaddings(\n filterShape: [number, number], dilation: [number, number]) {\n // Spatial dimensions of the filters and the upsampled filters in which we\n // introduce (rate - 1) zeros between consecutive filter values.\n const dilatedFilterShape = filterShape.map((s, i) => {\n return s + (s - 1) * (dilation[i] - 1);\n });\n const padExtraShape = dilatedFilterShape.map(s => s - 1);\n\n // When padding is odd, we pad more at end, following the same\n // convention as conv2d.\n const padExtraStart = padExtraShape.map(s => Math.floor(s / 2));\n const padExtraEnd = padExtraShape.map((s, i) => s - padExtraStart[i]);\n return padExtraShape.map((_, i) => {\n return [padExtraStart[i], padExtraEnd[i]];\n });\n}\n\n/**\n * Computes the 3D average pooling.\n *\n * ```js\n * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);\n * const result = tf.avgPool3d(x, 2, 1, 'valid');\n * result.print();\n * ```\n *\n * @param x The input tensor, of rank 5 or rank 4 of shape\n * `[batch, depth, height, width, inChannels]`.\n * @param filterSize The filter size:\n * `[filterDepth, filterHeight, filterWidth]`.\n * If `filterSize` is a single number,\n * then `filterDepth == filterHeight == filterWidth`.\n * @param strides The strides of the pooling:\n * `[strideDepth, strideHeight, strideWidth]`.\n * If `strides` is a single number,\n * then `strideDepth == strideHeight == strideWidth`.\n * @param pad The type of padding algorithm.\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1*1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n * @param dataFormat An optional string from: \"NDHWC\", \"NCDHW\". Defaults to\n * \"NDHWC\". Specify the data format of the input and output data. With the\n * default format \"NDHWC\", the data is stored in the order of: [batch,\n * depth, height, width, channels]. Only \"NDHWC\" is currently supported.\n * @param dilations The dilation rates:\n * `[dilationDepth, dilationHeight, dilationWidth]`\n * in which we sample input values across the depth, height and width\n * dimensions in dilated pooling.\n * Defaults to `[1, 1, 1]`. If `dilations` is a single number,\n * then `dilationDepth == dilationHeight == dilationWidth`.\n * If it is greater than 1, then all values of `strides` must be 1.\n */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction avgPool3d_(\n x: T|TensorLike,\n filterSize: [number, number, number]|number,\n strides: [number, number, number]|number,\n pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil',\n dataFormat: 'NDHWC'|'NCDHW' = 'NDHWC',\n dilations?: [number, number, number]|number,\n ): T {\n const $x = convertToTensor(x, 'x', 'avgPool3d', 'float32');\n\n let x5D = $x as Tensor5D;\n let reshapedTo5D = false;\n if ($x.rank === 4) {\n reshapedTo5D = true;\n x5D = $x.as5D(1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]);\n }\n\n if (dilations == null) {\n dilations = [1, 1, 1];\n }\n util.assert(\n x5D.rank === 5,\n () => `Error in avgPool3d: x must be rank 5 but got rank ${x5D.rank}.`);\n util.assert(\n dataFormat === 'NDHWC',\n () => `Error in avgPool3d: Only NDHWC is currently supported, ` +\n `but got dataFormat of ${dataFormat}`);\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () => 'Error in avgPool3d: Either strides or dilations must be 1. ' +\n `Got strides ${strides} and dilations '${dilations}'`);\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in avgPool3d: pad must be an integer when using, ` +\n `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n\n const convInfo = conv_util.computePool3DInfo(\n x5D.shape, filterSize, strides, dilations, pad, dimRoundingMode,\n dataFormat);\n\n const grad = (dy: Tensor5D) => {\n return {\n x: () => avgPool3dBackprop(\n dy, x5D, filterSize, strides, dilations, pad, dimRoundingMode)\n };\n };\n\n let res = ENGINE.runKernelFunc(\n backend => backend.avgPool3d(x5D, convInfo), {x: x5D}, grad);\n res = res.cast(x5D.dtype);\n if (reshapedTo5D) {\n return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]) as\n T;\n }\n\n return res as T;\n}\n\n/**\n * Computes the backprop of a 3d avg pool.\n *\n * @param dy The dy error, of rank 5 of shape\n * [batchSize, depth, height, width, channels].\n * assumed.\n * @param input The original input image, of rank 5 or rank4 of shape\n * [batchSize, depth, height, width, channels].\n * @param filterSize The filter size:\n * `[filterDepth, filterHeight, filterWidth]`.\n * `filterSize` is a single number,\n * then `filterDepth == filterHeight == filterWidth`.\n * @param strides The strides of the pooling:\n * `[strideDepth, strideHeight, strideWidth]`. If\n * `strides` is a single number, then `strideHeight == strideWidth`.\n * @param dilations The dilation rates:\n * `[dilationDepth, dilationHeight, dilationWidth]`\n * in which we sample input values across the depth, height and width\n * dimensions in dilated pooling.\n * Defaults to `[1, 1, 1]`. If `dilations` is a single number,\n * then `dilationDepth == dilationHeight == dilationWidth`.\n * If it is greater than 1, then all values of `strides` must be 1.\n * @param pad A string from: 'same', 'valid'. The type of padding algorithm\n * used in the forward prop of the op.\n * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The\n * rounding mode used when computing output dimensions if pad is a\n * number. If none is provided, it will not round and error if the output\n * is of fractional size.\n */\nfunction avgPool3dBackprop(\n dy: T|TensorLike, input: T|TensorLike,\n filterSize: [number, number, number]|number,\n strides: [number, number, number]|number,\n dilations: [number, number, number]|number, pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n const $dy = convertToTensor(dy, 'dy', 'avgPool3dBackprop');\n const $input = convertToTensor(input, 'input', 'avgPool3dBackprop');\n\n let dy5D = $dy as Tensor5D;\n let input5D = $input as Tensor5D;\n let reshapedTo5D = false;\n if ($input.rank === 4) {\n reshapedTo5D = true;\n dy5D = $dy.as5D(1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]);\n input5D = $input.as5D(\n 1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]);\n }\n\n util.assert(\n dy5D.rank === 5,\n () => `Error in avgPool3dBackprop: dy must be rank 5 but got rank ` +\n `${dy5D.rank}.`);\n util.assert(\n input5D.rank === 5,\n () => `Error in avgPool3dBackprop: input must be rank 5 but got rank ` +\n `${input5D.rank}.`);\n if (dilations == null) {\n dilations = [1, 1, 1];\n }\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () => 'Error in avgPool3dBackprop: Either strides or dilations ' +\n `must be 1. Got strides ${strides} and dilations '${dilations}'`);\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in maxPool3dBackprop: pad must be an integer when ` +\n `using, dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n\n const convInfo = conv_util.computePool3DInfo(\n input5D.shape, filterSize, strides, dilations, pad, dimRoundingMode);\n const res = ENGINE.runKernelFunc(\n backend => backend.avgPool3dBackprop(dy5D, input5D, convInfo),\n {dy5D, input5D});\n if (reshapedTo5D) {\n return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]) as\n T;\n }\n\n return res as T;\n}\n\n/**\n * Computes the 3D max pooling.\n *\n * ```js\n * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);\n * const result = tf.maxPool3d(x, 2, 1, 'valid');\n * result.print();\n * ```\n *\n * @param x The input tensor, of rank 5 or rank 4 of shape\n * `[batch, depth, height, width, inChannels]`.\n * @param filterSize The filter size:\n * `[filterDepth, filterHeight, filterWidth]`.\n * If `filterSize` is a single number,\n * then `filterDepth == filterHeight == filterWidth`.\n * @param strides The strides of the pooling:\n * `[strideDepth, strideHeight, strideWidth]`.\n * If `strides` is a single number,\n * then `strideDepth == strideHeight == strideWidth`.\n * @param pad The type of padding algorithm.\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1*1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n * @param dataFormat An optional string from: \"NDHWC\", \"NCDHW\". Defaults to\n * \"NDHWC\". Specify the data format of the input and output data. With the\n * default format \"NDHWC\", the data is stored in the order of: [batch,\n * depth, height, width, channels]. Only \"NDHWC\" is currently supported.\n * @param dilations The dilation rates:\n * `[dilationDepth, dilationHeight, dilationWidth]`\n * in which we sample input values across the depth, height and width\n * dimensions in dilated pooling.\n * Defaults to `[1, 1, 1]`. If `dilations` is a single number,\n * then `dilationDepth == dilationHeight == dilationWidth`.\n * If it is greater than 1, then all values of `strides` must be 1.\n */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction maxPool3d_(\n x: T|TensorLike, filterSize: [number, number, number]|number,\n strides: [number, number, number]|number, pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil',\n dataFormat: 'NDHWC'|'NCDHW' = 'NDHWC',\n dilations?: [number, number, number]|number): T {\n const $x = convertToTensor(x, 'x', 'maxPool3d');\n\n let x5D = $x as Tensor5D;\n let reshapedTo5D = false;\n if ($x.rank === 4) {\n reshapedTo5D = true;\n x5D = $x.as5D(1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]);\n }\n\n if (dilations == null) {\n dilations = [1, 1, 1];\n }\n util.assert(\n x5D.rank === 5,\n () => `Error in maxPool3d: x must be rank 5 but got rank ${x5D.rank}.`);\n util.assert(\n dataFormat === 'NDHWC',\n () => `Error in maxPool3d: Only NDHWC is currently supported, ` +\n `but got dataFormat of ${dataFormat}`);\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () => 'Error in maxPool3d: Either strides or dilations must be 1. ' +\n `Got strides ${strides} and dilations '${dilations}'`);\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in maxPool3d: pad must be an integer when using, ` +\n `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n\n const convInfo = conv_util.computePool3DInfo(\n x5D.shape, filterSize, strides, dilations, pad, dimRoundingMode,\n dataFormat);\n\n const grad = (dy: Tensor5D, saved: Tensor[]) => {\n const [x5D, y] = saved;\n return {\n x: () => maxPool3dBackprop(\n dy, x5D as Tensor5D, y as Tensor5D, filterSize, strides, dilations,\n pad, dimRoundingMode)\n };\n };\n\n const res = ENGINE.runKernelFunc((backend, save) => {\n const y = backend.maxPool3d(x5D, convInfo);\n save([x5D, y]);\n return y;\n }, {x: x5D}, grad);\n if (reshapedTo5D) {\n return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]) as\n T;\n }\n\n return res as T;\n}\n\n/**\n * Computes the backprop of a 3d max pool.\n *\n * @param dy The dy error, of rank 5 of shape\n * [batchSize, depth, height, width, channels].\n * assumed.\n * @param input The original input image, of rank 5 or rank 4 of shape\n * [batchSize, depth, height, width, channels].\n * @param output The original output image, of rank 5 of shape\n * [batchSize, outDepth, outHeight, outWidth, channels].\n * @param filterSize The filter size:\n * `[filterDepth, filterHeight, filterWidth]`.\n * `filterSize` is a single number,\n * then `filterDepth == filterHeight == filterWidth`.\n * @param strides The strides of the pooling:\n * `[strideDepth, strideHeight, strideWidth]`. If\n * `strides` is a single number, then `strideHeight == strideWidth`.\n * @param dilations The dilation rates:\n * `[dilationDepth, dilationHeight, dilationWidth]`\n * in which we sample input values across the depth, height and width\n * dimensions in dilated pooling.\n * Defaults to `[1, 1, 1]`. If `dilations` is a single number,\n * then `dilationDepth == dilationHeight == dilationWidth`.\n * If it is greater than 1, then all values of `strides` must be 1.\n * @param pad A string from: 'same', 'valid'. The type of padding algorithm\n * used in the forward prop of the op.\n * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The\n * rounding mode used when computing output dimensions if pad is a\n * number. If none is provided, it will not round and error if the output\n * is of fractional size.\n */\nfunction maxPool3dBackprop(\n dy: T|TensorLike, input: T|TensorLike, output: T|TensorLike,\n filterSize: [number, number, number]|number,\n strides: [number, number, number]|number,\n dilations: [number, number, number]|number, pad: 'valid'|'same'|number,\n dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n const $dy = convertToTensor(dy, 'dy', 'maxPool3dBackprop');\n const $input = convertToTensor(input, 'input', 'maxPool3dBackprop');\n const $output = convertToTensor(output, 'output', 'maxPool3dBackprop');\n\n let dy5D = $dy as Tensor5D;\n let input5D = $input as Tensor5D;\n let output5D = $output as Tensor5D;\n let reshapedTo5D = false;\n if ($input.rank === 4) {\n reshapedTo5D = true;\n dy5D = $dy.as5D(1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]);\n input5D = $input.as5D(\n 1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]);\n output5D = $output.as5D(\n 1, $output.shape[0], $output.shape[1], $output.shape[2],\n $output.shape[3]);\n }\n\n util.assert(\n dy5D.rank === 5,\n () => `Error in maxPool3dBackprop: dy must be rank 5 but got rank ` +\n `${dy5D.rank}.`);\n util.assert(\n input5D.rank === 5,\n () => `Error in maxPool3dBackprop: input must be rank 5 but got rank ` +\n `${input5D.rank}.`);\n util.assert(\n output5D.rank === 5,\n () => `Error in maxPool3dBackprop: output must be rank 5 but got rank ` +\n `${output5D.rank}.`);\n if (dilations == null) {\n dilations = [1, 1, 1];\n }\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () => 'Error in maxPool3dBackprop: Either strides or dilations ' +\n `must be 1. Got strides ${strides} and dilations '${dilations}'`);\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in maxPool3dBackprop: pad must be an integer when ` +\n `using, dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n\n const convInfo = conv_util.computePool3DInfo(\n input5D.shape, filterSize, strides, dilations, pad, dimRoundingMode);\n const res = ENGINE.runKernelFunc(\n backend => backend.maxPool3dBackprop(dy5D, input5D, output5D, convInfo),\n {dy5D, input5D});\n\n if (reshapedTo5D) {\n return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]) as\n T;\n }\n\n return res as T;\n}\n\n/**\n * Computes the 2D max pooling of an image with Argmax index.\n * The indices in argmax are flattened, so that a maximum value at position `[b,\n * y, x, c]` becomes flattened index: `(y * width + x) * channels + c` if\n * include_batch_in_index is False; `((b * height + y) * width + x) * channels\n * +c` if include_batch_in_index is True.\n *\n * The indices returned are always in `[0, height) x [0, width)` before\n * flattening.\n *\n * @param x The input tensor, of rank 4 or rank 3 of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.\n * @param filterSize The filter size: `[filterHeight, filterWidth]`. If\n * `filterSize` is a single number, then `filterHeight == filterWidth`.\n * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If\n * `strides` is a single number, then `strideHeight == strideWidth`.\n * @param dataFormat An optional string from: \"NDHWC\", \"NCDHW\". Defaults to\n * \"NDHWC\". Specify the data format of the input and output data. With the\n * default format \"NDHWC\", the data is stored in the order of: [batch,\n * depth, height, width, channels]. Only \"NDHWC\" is currently supported.\n * @param pad The type of padding algorithm.\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param includeBatchIndex Defaults to False. Whether to include batch\n * dimension in flattened index of argmax.\n */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\n/** @doc {heading: 'Operations', subheading: 'Convolution'} */\nfunction maxPoolWithArgmax_(\n x: T|TensorLike, filterSize: [number, number]|number,\n strides: [number, number]|number,\n pad: 'valid'|'same'|number, includeBatchInIndex = false): NamedTensorMap {\n const $x = convertToTensor(x, 'x', 'maxPoolWithArgmax');\n\n const attrs = {filterSize, strides, pad, includeBatchInIndex};\n\n const result =\n ENGINE.runKernel('MaxPoolWithArgmax', {x: $x}, attrs) as Tensor[];\n\n return {result: result[0], indexes: result[1]};\n}\n\nexport const maxPool = op({maxPool_});\nexport const avgPool = op({avgPool_});\nexport const pool = op({pool_});\nexport const maxPool3d = op({maxPool3d_});\nexport const avgPool3d = op({avgPool3d_});\nexport const maxPoolWithArgmax = op({maxPoolWithArgmax_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {Rank, TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {op} from './operation';\nimport {pad} from './pad';\nimport * as slice_util from './slice_util';\n\n/**\n * Extracts a 1D slice from 1D array starting at coordinates `begin` and is\n * of length `size`. See `slice` for details.\n */\nfunction slice1d_(\n x: Tensor1D|TensorLike, begin: number, size: number): Tensor1D {\n const $x = convertToTensor(x, 'x', 'slice1d');\n util.assert(\n $x.rank === 1,\n () =>\n `slice1d expects a rank-1 tensor, but got a rank-${$x.rank} tensor`);\n return slice($x, [begin], [size]);\n}\n\n/**\n * Extracts a 2D slice from a 2D array starting at coordinates `begin` and\n * is of size `size`. See `slice` for details.\n */\nfunction slice2d_(\n x: Tensor2D|TensorLike, begin: [number, number],\n size: [number, number]): Tensor2D {\n const $x = convertToTensor(x, 'x', 'slice2d');\n util.assert(\n $x.rank === 2,\n () =>\n `slice2d expects a rank-2 tensor, but got a rank-${$x.rank} tensor`);\n return slice($x, begin, size);\n}\n\n/**\n * Extracts a 3D slice from a 3D array starting at coordinates `begin` and\n * is of size `size`. See `slice` for details.\n */\nfunction slice3d_(\n x: Tensor3D|TensorLike, begin: [number, number, number],\n size: [number, number, number]): Tensor3D {\n const $x = convertToTensor(x, 'x', 'slice3d');\n util.assert(\n $x.rank === 3,\n () =>\n `slice3d expects a rank-3 tensor, but got a rank-${$x.rank} tensor`);\n return slice($x, begin, size);\n}\n\n/**\n * Extracts a 4D slice from a 4D array starting at coordinates `begin` and\n * is of size `size`. See `slice` for details.\n */\nfunction slice4d_(\n x: Tensor4D|TensorLike, begin: [number, number, number, number],\n size: [number, number, number, number]): Tensor4D {\n const $x = convertToTensor(x, 'x', 'slice4d');\n util.assert(\n $x.rank === 4,\n () =>\n `slice4d expects a rank-4 tensor, but got a rank-${$x.rank} tensor`);\n return slice($x, begin, size);\n}\n\n/**\n * Extracts a slice from a `tf.Tensor` starting at coordinates `begin`\n * and is of size `size`.\n *\n * Also available are stricter rank-specific methods with the same signature\n * as this method that assert that `x` is of the given rank:\n * - `tf.slice1d`\n * - `tf.slice2d`\n * - `tf.slice3d`\n * - `tf.slice4d`\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3, 4]);\n *\n * x.slice([1], [2]).print();\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n *\n * x.slice([1, 0], [1, 2]).print();\n * ```\n * @param x The input `tf.Tensor` to slice from.\n * @param begin The coordinates to start the slice from. The length can be\n * less than the rank of x - the rest of the axes will have implicit 0 as\n * start. Can also be a single number, in which case it specifies the\n * first axis.\n * @param size The size of the slice. The length can be less than the rank of\n * x - the rest of the axes will have implicit -1. A value of -1 requests\n * the rest of the dimensions in the axis. Can also be a single number,\n * in which case it specifies the size of the first axis.\n */\n/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */\nfunction slice_>(\n x: T|TensorLike, begin: number|number[], size?: number|number[]): T {\n const $x = convertToTensor(x, 'x', 'slice');\n\n if ($x.rank === 0) {\n throw new Error('Slicing scalar is not possible');\n }\n // The following logic allows for more ergonomic calls.\n let begin_: number[];\n if (typeof begin === 'number') {\n begin_ = [begin, ...new Array($x.rank - 1).fill(0)];\n } else if (begin.length < $x.rank) {\n begin_ = begin.concat(new Array($x.rank - begin.length).fill(0));\n } else {\n begin_ = begin.slice();\n }\n begin_.forEach(d => {\n util.assert(\n d !== -1, () => 'slice() does not support negative begin indexing.');\n });\n let size_: number[];\n if (size == null) {\n size_ = new Array($x.rank).fill(-1);\n } else if (typeof size === 'number') {\n size_ = [size, ...new Array($x.rank - 1).fill(-1)];\n } else if (size.length < $x.rank) {\n size_ = size.concat(new Array($x.rank - size.length).fill(-1));\n } else {\n size_ = size;\n }\n size_ = size_.map((d, i) => {\n if (d >= 0) {\n return d;\n } else {\n util.assert(\n d === -1,\n () => `Negative size values should be exactly -1 but got ` +\n `${d} for the slice() size at index ${i}.`);\n return $x.shape[i] - begin_[i];\n }\n });\n slice_util.assertParamsValid($x, begin_, size_);\n const inputShape = $x.shape;\n const grad = (dy: T) => {\n // Create an Nx2 padding where the first column represents how many\n // zeros are prepended (at start) for each dimension, and the second\n // column indicates how many zeros are appended (at end).\n\n // The number of zeros to append is the shape of the input\n // elementwise-subtracted by both the begin vector and sizes vector.\n const paddings: Array<[number, number]> = [];\n for (let i = 0; i < dy.rank; i++) {\n paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]);\n }\n return {x: () => pad(dy, paddings)};\n };\n const attrs = {begin: begin_, size: size_};\n return ENGINE.runKernelFunc(\n backend => backend.slice($x, begin_, size_), {x: $x}, grad, 'Slice',\n attrs);\n}\n\nexport const slice = op({slice_});\nexport const slice1d = op({slice1d_});\nexport const slice2d = op({slice2d_});\nexport const slice3d = op({slice3d_});\nexport const slice4d = op({slice4d_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {customGrad} from '../gradients';\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\nimport * as axis_util from './axis_util';\nimport {op} from './operation';\nimport {ones, scalar, zerosLike} from './tensor_ops';\n\n/**\n * Computes the log(sum(exp(elements across the reduction dimensions)).\n *\n * Reduces the input along the dimensions given in `axis`. Unless `keepDims`\n * is true, the rank of the array is reduced by 1 for each entry in `axis`.\n * If `keepDims` is true, the reduced dimensions are retained with length 1.\n * If `axis` has no entries, all dimensions are reduced, and an array with a\n * single element is returned.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3]);\n *\n * x.logSumExp().print(); // or tf.logSumExp(x)\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n *\n * const axis = 1;\n * x.logSumExp(axis).print(); // or tf.logSumExp(a, axis)\n * ```\n * @param x The input tensor.\n * @param axis The dimension(s) to reduce. If null (the default),\n * reduces all dimensions.\n * @param keepDims If true, retains reduced dimensions with length\n * of 1. Defaults to false.\n */\n/** @doc {heading: 'Operations', subheading: 'Reduction'} */\nfunction logSumExp_(\n x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T {\n const $x = convertToTensor(x, 'x', 'logSumExp');\n\n const axes = util.parseAxisParam(axis, $x.shape);\n const xMax = $x.max(axes, true /* keepDims */);\n const a = $x.sub(xMax);\n const b = a.exp();\n const c = b.sum(axes);\n const d = c.log();\n const res = xMax.reshape(d.shape).add(d);\n\n if (keepDims) {\n const newShape = axis_util.expandShapeToKeepDim(res.shape, axes);\n return res.reshape(newShape) as T;\n }\n return res as T;\n}\n\n/**\n * Computes the sum of elements across dimensions of a `tf.Tensor`.\n *\n * Reduces the input along the dimensions given in `axes`. Unless `keepDims`\n * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in\n * `axes`. If `keepDims` is true, the reduced dimensions are retained with\n * length 1. If axes has no entries, all dimensions are reduced, and a\n * `tf.Tensor` with a single element is returned.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3]);\n *\n * x.sum().print(); // or tf.sum(x)\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n *\n * const axis = 1;\n * x.sum(axis).print(); // or tf.sum(x, axis)\n * ```\n *\n * @param x The input tensor to compute the sum over. If the dtype is `bool`\n * it will be converted to `int32` and the output dtype will be `int32`.\n * @param axis The dimension(s) to reduce. By default it reduces\n * all dimensions.\n * @param keepDims If true, retains reduced dimensions with size 1.\n */\n/** @doc {heading: 'Operations', subheading: 'Reduction'} */\nfunction sum_(\n x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T {\n let $x = convertToTensor(x, 'x', 'sum');\n\n if ($x.dtype === 'bool') {\n $x = $x.toInt();\n }\n const axes = util.parseAxisParam(axis, $x.shape);\n\n // Use a custom gradient to bypass 2 gradient backprops since sum is used\n // extremely often.\n const customOp = customGrad((x: Tensor) => {\n const permutation = axis_util.getAxesPermutation(axes, x.rank);\n let reductionAxes = axes;\n let permutedX = x;\n if (permutation != null) {\n permutedX = x.transpose(permutation);\n reductionAxes = axis_util.getInnerMostAxes(reductionAxes.length, x.rank);\n }\n\n const gradFunc = (dy: Tensor) => {\n const expandedDyShape = x.shape.slice();\n axes.forEach(axis => {\n expandedDyShape[axis] = 1;\n });\n const expandedDy = dy.reshape(expandedDyShape);\n const derX = expandedDy.mul(ones(x.shape, 'float32'));\n return derX;\n };\n\n const gradInputs = (dy: Tensor) => {\n return {x: () => gradFunc(dy)};\n };\n\n const attrs = {axes: reductionAxes};\n let value = ENGINE.runKernelFunc(\n backend => backend.sum(permutedX, reductionAxes), {x: permutedX},\n gradInputs, 'Sum', attrs);\n\n if (keepDims) {\n const newShape = axis_util.expandShapeToKeepDim(value.shape, axes);\n value = value.reshape(newShape);\n }\n\n return {value, gradFunc};\n });\n\n return customOp($x) as T;\n}\n\n/**\n * Computes the product of elements across dimensions of a `tf.Tensor`.\n *\n * Reduces the input along the dimensions given in `axes`. Unless `keepDims`\n * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in\n * `axes`. If `keepDims` is true, the reduced dimensions are retained with\n * length 1. If `axes` has no entries, all dimensions are reduced, and a\n * `tf.Tensor` with a single element is returned.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3]);\n *\n * x.prod().print(); // or tf.prod(x)\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n *\n * const axis = 1;\n * x.prod(axis).print(); // or tf.prod(x, axis)\n * ```\n *\n * @param x The input tensor to compute the product over. If the dtype is `bool`\n * it will be converted to `int32` and the output dtype will be `int32`.\n * @param axis The dimension(s) to reduce. By default it reduces\n * all dimensions.\n * @param keepDims If true, retains reduced dimensions with size 1.\n */\n/** @doc {heading: 'Operations', subheading: 'Reduction'} */\nfunction prod_(\n x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T {\n let $x = convertToTensor(x, 'x', 'prod');\n\n if ($x.dtype === 'bool') {\n $x = $x.toInt();\n }\n const axes = util.parseAxisParam(axis, $x.shape);\n\n const permutation = axis_util.getAxesPermutation(axes, $x.rank);\n let reductionAxes = axes;\n let permutedX = $x;\n if (permutation != null) {\n permutedX = $x.transpose(permutation);\n reductionAxes = axis_util.getInnerMostAxes(reductionAxes.length, $x.rank);\n }\n let value = ENGINE.runKernelFunc(\n backend => backend.prod(permutedX, reductionAxes), {permutedX});\n if (keepDims) {\n const newShape = axis_util.expandShapeToKeepDim(value.shape, axes);\n value = value.reshape(newShape);\n }\n\n return value as T;\n}\n/**\n * Computes the mean of elements across dimensions of a `tf.Tensor`.\n *\n * Reduces `x` along the dimensions given in `axis`. Unless `keepDims` is\n * true, the rank of the `tf.Tensor` is reduced by 1 for each entry in `axis`.\n * If `keepDims` is true, the reduced dimensions are retained with length 1.\n * If `axis` has no entries, all dimensions are reduced, and a `tf.Tensor` with\n * a single element is returned.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3]);\n *\n * x.mean().print(); // or tf.mean(a)\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n *\n * const axis = 1;\n * x.mean(axis).print(); // or tf.mean(x, axis)\n * ```\n *\n * @param x The input tensor.\n * @param axis The dimension(s) to reduce. By default it reduces\n * all dimensions.\n * @param keepDims If true, retains reduced dimensions with size 1.\n */\n/** @doc {heading: 'Operations', subheading: 'Reduction'} */\nfunction mean_(\n x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T {\n const $x = convertToTensor(x, 'x', 'mean');\n\n const axes = util.parseAxisParam(axis, $x.shape);\n const shapes = axis_util.computeOutAndReduceShapes($x.shape, axes);\n const reduceShape = shapes[1];\n const reduceSize = util.sizeFromShape(reduceShape);\n\n // Use a custom gradient to bypass 2 gradient backprops since mean is used\n // extremely often.\n const customOp = customGrad((x: Tensor) => {\n const reduceSizeScalar = scalar(reduceSize);\n // Cast if needed.\n const xReduce =\n reduceSizeScalar.dtype === x.dtype ? x : x.cast(reduceSizeScalar.dtype);\n const res = xReduce.div(reduceSizeScalar);\n const value = res.sum(axis, keepDims);\n\n const gradFunc = (dy: Tensor) => {\n const expandedDyShape = x.shape.slice();\n axes.forEach(axis => {\n expandedDyShape[axis] = 1;\n });\n const expandedDy = dy.reshape(expandedDyShape);\n const derX = expandedDy.mul(ones(x.shape, 'float32')).div(reduceSize);\n return derX;\n };\n return {value, gradFunc};\n });\n\n return customOp($x) as T;\n}\n\n/**\n * Gradient helper function for the min and max operations.\n */\nfunction gradForMinAndMax(\n dy: T, y: T, xOrig: Tensor, origAxes: number[], permutedAxes: number[]) {\n if (y.rank < xOrig.rank) {\n y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)) as T;\n }\n if (dy.rank < xOrig.rank) {\n dy = dy.reshape(axis_util.expandShapeToKeepDim(dy.shape, origAxes)) as T;\n }\n return {\n x: () => {\n const dx = dy.mul(xOrig.equal(y).cast(dy.dtype));\n return permutedAxes == null ? dx : dx.transpose(permutedAxes);\n }\n };\n}\n\n/**\n * Computes the minimum value from the input.\n *\n * Reduces the input along the dimensions given in `axes`. Unless `keepDims`\n * is true, the rank of the array is reduced by 1 for each entry in `axes`.\n * If `keepDims` is true, the reduced dimensions are retained with length 1.\n * If `axes` has no entries, all dimensions are reduced, and an array with a\n * single element is returned.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3]);\n *\n * x.min().print(); // or tf.min(x)\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n *\n * const axis = 1;\n * x.min(axis).print(); // or tf.min(x, axis)\n * ```\n *\n * @param x The input Tensor.\n * @param axis The dimension(s) to reduce. By default it reduces\n * all dimensions.\n * @param keepDims If true, retains reduced dimensions with size 1.\n */\n/** @doc {heading: 'Operations', subheading: 'Reduction'} */\nfunction min_(\n x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T {\n let $x = convertToTensor(x, 'x', 'min');\n const xOrig = $x;\n\n const origAxes = util.parseAxisParam(axis, $x.shape);\n let axes = origAxes;\n const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank);\n if (permutedAxes != null) {\n $x = $x.transpose(permutedAxes);\n axes = axis_util.getInnerMostAxes(axes.length, $x.rank);\n }\n\n const grad = (dy: T, saved: Tensor[]) =>\n gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes);\n\n const inputsToSave = [$x];\n const outputsToSave: boolean[] = [true];\n let res = ENGINE.runKernelFunc((backend, save) => {\n const y = backend.min($x, axes);\n save([xOrig, y]);\n return y as T;\n }, {x: $x}, grad, 'Min', {axes}, inputsToSave, outputsToSave);\n if (keepDims) {\n const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes);\n res = res.reshape(newShape) as T;\n }\n return res;\n}\n\n/**\n * Computes the maximum of elements across dimensions of a `tf.Tensor`.\n *\n * Reduces the input along the dimensions given in `axes`. Unless `keepDims`\n * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in\n * `axes`. If `keepDims` is true, the reduced dimensions are retained with\n * length 1. If `axes` has no entries, all dimensions are reduced, and an\n * `tf.Tensor` with a single element is returned.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3]);\n *\n * x.max().print(); // or tf.max(x)\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n *\n * const axis = 1;\n * x.max(axis).print(); // or tf.max(x, axis)\n * ```\n *\n * @param x The input tensor.\n * @param axis The dimension(s) to reduce. By default it reduces\n * all dimensions.\n * @param keepDims If true, retains reduced dimensions with size 1.\n */\n/** @doc {heading: 'Operations', subheading: 'Reduction'} */\nfunction max_(\n x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T {\n let $x = convertToTensor(x, 'x', 'max');\n const xOrig = $x;\n\n const origAxes = util.parseAxisParam(axis, $x.shape);\n let axes = origAxes;\n const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank);\n if (permutedAxes != null) {\n $x = $x.transpose(permutedAxes);\n axes = axis_util.getInnerMostAxes(axes.length, $x.rank);\n }\n\n const grad = (dy: T, saved: Tensor[]) =>\n gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes);\n\n const inputsToSave = [$x];\n const outputsToSave: boolean[] = [true];\n let res = ENGINE.runKernelFunc((backend, save) => {\n const y = backend.max($x, axes);\n save([xOrig, y]);\n return y;\n }, {x: $x}, grad, 'Max', {axes}, inputsToSave, outputsToSave);\n if (keepDims) {\n const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes);\n res = res.reshape(newShape) as T;\n }\n return res as T;\n}\n\n/**\n * Returns the indices of the minimum values along an `axis`.\n *\n * The result has the same shape as `input` with the dimension along `axis`\n * removed.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3]);\n *\n * x.argMin().print(); // or tf.argMin(x)\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);\n *\n * const axis = 1;\n * x.argMin(axis).print(); // or tf.argMin(x, axis)\n * ```\n *\n * @param x The input tensor.\n * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).\n *\n */\n/** @doc {heading: 'Operations', subheading: 'Reduction'} */\nfunction argMin_(x: Tensor|TensorLike, axis = 0): T {\n let $x = convertToTensor(x, 'x', 'argMin');\n\n if (axis == null) {\n axis = 0;\n }\n let axes = util.parseAxisParam(axis, $x.shape);\n const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank);\n if (permutedAxes != null) {\n $x = $x.transpose(permutedAxes);\n axes = axis_util.getInnerMostAxes(axes.length, $x.rank);\n }\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {$x: () => zerosLike($x)};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.argMin($x, axes[0]);\n save([$x]);\n return res;\n }, {$x}, grad) as T;\n}\n\n/**\n * Returns the indices of the maximum values along an `axis`.\n *\n * The result has the same shape as `input` with the dimension along `axis`\n * removed.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3]);\n *\n * x.argMax().print(); // or tf.argMax(x)\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);\n *\n * const axis = 1;\n * x.argMax(axis).print(); // or tf.argMax(x, axis)\n * ```\n *\n * @param x The input tensor.\n * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).\n */\n/** @doc {heading: 'Operations', subheading: 'Reduction'} */\nfunction argMax_(x: Tensor|TensorLike, axis = 0): T {\n let $x = convertToTensor(x, 'x', 'argMax');\n\n if (axis == null) {\n axis = 0;\n }\n let axes = util.parseAxisParam(axis, $x.shape);\n const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank);\n if (permutedAxes != null) {\n $x = $x.transpose(permutedAxes);\n axes = axis_util.getInnerMostAxes(axes.length, $x.rank);\n }\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {x: () => zerosLike($x)};\n };\n const attrs = {axis: axes[0]};\n const inputsToSave = [$x];\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.argMax($x, axes[0]);\n save([$x]);\n return res;\n }, {x: $x}, grad, 'ArgMax', attrs, inputsToSave) as T;\n}\n\n/**\n * Computes the logical and of elements across dimensions of a `tf.Tensor`.\n *\n * Reduces the input along the dimensions given in `axes`. Unless `keepDims`\n * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in\n * `axes`. If `keepDims` is true, the reduced dimensions are retained with\n * length 1. If `axes` has no entries, all dimensions are reduced, and an\n * `tf.Tensor` with a single element is returned.\n *\n * ```js\n * const x = tf.tensor1d([1, 1, 1], 'bool');\n *\n * x.all().print(); // or tf.all(x)\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');\n *\n * const axis = 1;\n * x.all(axis).print(); // or tf.all(x, axis)\n * ```\n *\n * @param x The input tensor. Must be of dtype bool.\n * @param axis The dimension(s) to reduce. By default it reduces\n * all dimensions.\n * @param keepDims If true, retains reduced dimensions with size 1.\n */\n/** @doc {heading: 'Operations', subheading: 'Reduction'} */\nfunction all_(\n x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T {\n let $x = convertToTensor(x, 'x', 'all', 'bool');\n\n const origAxes = util.parseAxisParam(axis, $x.shape);\n let axes = origAxes;\n const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank);\n if (permutedAxes != null) {\n $x = $x.transpose(permutedAxes);\n axes = axis_util.getInnerMostAxes(axes.length, $x.rank);\n }\n const res = ENGINE.runKernelFunc(backend => backend.all($x, axes), {$x});\n if (keepDims) {\n const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes);\n return res.reshape(newShape) as T;\n }\n return res as T;\n}\n\n/**\n * Computes the logical or of elements across dimensions of a `tf.Tensor`.\n *\n * Reduces the input along the dimensions given in `axes`. Unless `keepDims`\n * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in\n * `axes`. If `keepDims` is true, the reduced dimensions are retained with\n * length 1. If `axes` has no entries, all dimensions are reduced, and an\n * `tf.Tensor` with a single element is returned.\n *\n * ```js\n * const x = tf.tensor1d([1, 1, 1], 'bool');\n *\n * x.any().print(); // or tf.any(x)\n * ```\n *\n * ```js\n * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');\n *\n * const axis = 1;\n * x.any(axis).print(); // or tf.any(x, axis)\n * ```\n *\n * @param x The input tensor. Must be of dtype bool.\n * @param axis The dimension(s) to reduce. By default it reduces\n * all dimensions.\n * @param keepDims If true, retains reduced dimensions with size 1.\n */\n/** @doc {heading: 'Operations', subheading: 'Reduction'} */\nfunction any_(\n x: Tensor|TensorLike, axis: number|number[] = null, keepDims = false): T {\n let $x = convertToTensor(x, 'x', 'any', 'bool');\n\n const origAxes = util.parseAxisParam(axis, $x.shape);\n let axes = origAxes;\n const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank);\n if (permutedAxes != null) {\n $x = $x.transpose(permutedAxes);\n axes = axis_util.getInnerMostAxes(axes.length, $x.rank);\n }\n const res = ENGINE.runKernelFunc(backend => backend.any($x, axes), {$x});\n if (keepDims) {\n const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes);\n return res.reshape(newShape) as T;\n }\n return res as T;\n}\n\n/**\n * Calculates the mean and variance of `x`. The mean and variance are\n * calculated by aggregating the contents of `x` across `axes`. If `x` is\n * 1-D and `axes = [0]` this is just the mean and variance of a vector.\n *\n * @param x The input tensor.\n * @param axis The dimension(s) along with to compute mean and\n * variance. By default it reduces all dimensions.\n * @param keepDims If true, the moments have the same dimensionality as the\n * input.\n * @return An object with two keys: `mean` and `variance`.\n */\n/** @doc {heading: 'Operations', subheading: 'Normalization'} */\nfunction moments_(\n x: Tensor|TensorLike, axis: number|number[] = null,\n keepDims = false): {mean: Tensor, variance: Tensor} {\n x = convertToTensor(x, 'x', 'moments');\n const axes = util.parseAxisParam(axis, x.shape);\n const mean = x.mean(axes, keepDims);\n let keepDimsShape = mean.shape;\n if (!keepDims) {\n keepDimsShape = axis_util.expandShapeToKeepDim(mean.shape, axes);\n }\n const devSquared = x.toFloat().sub(mean.reshape(keepDimsShape)).square();\n const variance = devSquared.mean(axes, keepDims);\n return {mean, variance};\n}\n\nexport const all = op({all_});\n// tslint:disable-next-line:variable-name\nexport const any = op({any_});\nexport const argMax = op({argMax_});\nexport const argMin = op({argMin_});\nexport const logSumExp = op({logSumExp_});\nexport const max = op({max_});\nexport const mean = op({mean_});\nexport const min = op({min_});\nexport const moments = op({moments_});\nexport const sum = op({sum_});\nexport const prod = op({prod_});\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {maximum} from './binary_ops';\nimport {getReductionAxes} from './broadcast_util';\nimport {where} from './logical_ops';\nimport {op} from './operation';\nimport {SELU_SCALE, SELU_SCALEALPHA} from './selu_util';\nimport {scalar, zerosLike} from './tensor_ops';\n\n/**\n * Computes rectified linear element-wise: `max(x, 0)`.\n *\n * ```js\n * const x = tf.tensor1d([-1, 2, -3, 4]);\n *\n * x.relu().print(); // or tf.relu(x)\n * ```\n * @param x The input tensor. If the dtype is `bool`, the output dtype will be\n * `int32'.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction relu_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'relu');\n\n if ($x.dtype === 'bool') {\n return $x.toInt();\n }\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {x: () => dy.mulStrict($x.step().toFloat() as T)};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.relu($x);\n save([$x]);\n return res;\n }, {x: $x}, grad, 'Relu');\n}\n\n/**\n * Computes rectified linear 6 element-wise: `min(max(x, 0), 6)`.\n *\n * ```js\n * const x = tf.tensor1d([-1, 2, -3, 8]);\n *\n * x.relu6().print(); // or tf.relu6(x)\n * ```\n * @param x The input tensor. If the dtype is `bool`, the output dtype will be\n * `int32'.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction relu6_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'relu6');\n\n if ($x.dtype === 'bool') {\n return $x.toInt();\n }\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n const mask = $x.lessEqual(6).mul($x.step());\n return {x: () => dy.mulStrict(mask.toFloat() as T)};\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.relu6($x);\n save([$x]);\n return res;\n }, {x: $x}, grad, 'Relu6');\n}\n\n/**\n * Computes exponential linear element-wise: `x > 0 ? e ^ x - 1 : 0`.\n *\n * ```js\n * const x = tf.tensor1d([-1, 1, -3, 2]);\n *\n * x.elu().print(); // or tf.elu(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction elu_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'elu');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [y] = saved;\n return {\n $x: () =>\n ENGINE.runKernelFunc(backend => backend.eluDer(dy, y), {dy, y}) as T\n };\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const y = backend.elu($x);\n save([y]);\n return y;\n }, {$x}, grad);\n}\n\n/**\n * Computes scaled exponential linear element-wise.\n *\n * `x < 0 ? scale * alpha * (exp(x) - 1) : x`\n *\n * ```js\n * const x = tf.tensor1d([-1, 2, -3, 4]);\n *\n * x.selu().print(); // or tf.selu(x)\n * ```\n * @param x The input tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction selu_(x: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'selu');\n\n const grad = (dy: T, saved: Tensor[]) => {\n const [$x] = saved;\n return {\n $x: () => {\n const mask = $x.greater(scalar(0));\n\n const scaleAlpha = scalar(SELU_SCALEALPHA);\n const scale = scalar(SELU_SCALE);\n\n const greaterThanZeroDer = dy.mul(scale);\n const lessEqualZeroDer = dy.mul(scaleAlpha).mul($x.toFloat().exp());\n\n return where(mask, greaterThanZeroDer, lessEqualZeroDer) as T;\n }\n };\n };\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.selu($x);\n save([$x]);\n return res;\n }, {$x}, grad);\n}\n\n/**\n * Computes leaky rectified linear element-wise.\n *\n * See\n * [http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf](\n * http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf)\n *\n * ```js\n * const x = tf.tensor1d([-1, 2, -3, 4]);\n *\n * x.leakyRelu(0.1).print(); // or tf.leakyRelu(x, 0.1)\n * ```\n * @param x The input tensor.\n * @param alpha The scaling factor for negative values, defaults to 0.2.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction leakyRelu_(x: T|TensorLike, alpha = 0.2): T {\n const $x = convertToTensor(x, 'x', 'leakyRelu');\n return maximum(scalar(alpha).mul($x), $x);\n}\n\n/**\n * Computes leaky rectified linear element-wise with parametric alphas.\n *\n * `x < 0 ? alpha * x : f(x) = x`\n *\n * ```js\n * const x = tf.tensor1d([-1, 2, -3, 4]);\n * const alpha = tf.scalar(0.1);\n *\n * x.prelu(alpha).print(); // or tf.prelu(x, alpha)\n * ```\n * @param x The input tensor.\n * @param alpha Scaling factor for negative values.\n */\n/** @doc {heading: 'Operations', subheading: 'Basic math'} */\nfunction prelu_(x: T|TensorLike, alpha: T|TensorLike): T {\n const $x = convertToTensor(x, 'x', 'prelu');\n const $alpha = convertToTensor(alpha, 'alpha', 'prelu');\n\n const grad = (dy: Tensor, saved: Tensor[]) => {\n const [$x, $alpha] = saved;\n const mask = $x.greater(0);\n\n return {\n x: () => where(mask, dy, dy.mul($alpha)) as T,\n alpha: () => {\n let res = where(mask, zerosLike(dy), dy.mul($x));\n const reduceAxes = getReductionAxes($alpha.shape, dy.shape);\n if (reduceAxes.length > 0) {\n res = res.sum(reduceAxes);\n }\n return res.reshape($alpha.shape) as T;\n }\n };\n };\n\n return ENGINE.runKernelFunc((backend, save) => {\n const res = backend.prelu($x, $alpha);\n save([$x, $alpha]);\n return res;\n }, {x: $x, alpha: $alpha}, grad, 'Prelu') as T;\n}\n\nexport const elu = op({elu_});\nexport const leakyRelu = op({leakyRelu_});\nexport const prelu = op({prelu_});\nexport const relu = op({relu_});\nexport const relu6 = op({relu6_});\nexport const selu = op({selu_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor, Tensor3D, Tensor4D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\nimport {op} from './operation';\n\n/**\n * Normalizes the activation of a local neighborhood across or within\n * channels.\n *\n * @param x The input tensor. The 4-D input tensor is treated as a 3-D array\n * of 1D vectors (along the last dimension), and each vector is\n * normalized independently.\n * @param depthRadius The number of adjacent channels in the 1D normalization\n * window.\n * @param bias A constant bias term for the basis.\n * @param alpha A scale factor, usually positive.\n * @param beta An exponent.\n */\n/** @doc {heading: 'Operations', subheading: 'Normalization'} */\nfunction localResponseNormalization_(\n x: T|TensorLike, depthRadius = 5, bias = 1, alpha = 1, beta = 0.5): T {\n const $x = convertToTensor(x, 'x', 'localResponseNormalization');\n util.assert(\n $x.rank === 4 || $x.rank === 3,\n () => `Error in localResponseNormalization: x must be rank 3 or 4 but got\n rank ${$x.rank}.`);\n util.assert(\n util.isInt(depthRadius),\n () => `Error in localResponseNormalization: depthRadius must be an ` +\n `integer but got depthRadius ${depthRadius}.`);\n\n let x4D = $x as Tensor4D;\n let reshapedTo4D = false;\n if ($x.rank === 3) {\n reshapedTo4D = true;\n x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]);\n }\n const backward = (dy: Tensor4D, saved: Tensor[]) => {\n const [x4D, y] = saved;\n return {\n x4D: () => ENGINE.runKernelFunc(\n backend => backend.LRNGrad(\n dy, x4D as Tensor4D, y as Tensor4D, depthRadius, bias, alpha,\n beta),\n {})\n };\n };\n const res = ENGINE.runKernelFunc((backend, save) => {\n const y = backend.localResponseNormalization4D(\n x4D, depthRadius, bias, alpha, beta);\n save([x4D, y]);\n return y;\n }, {x4D}, backward);\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n } else {\n return res as T;\n }\n}\n\nexport const localResponseNormalization = op({localResponseNormalization_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {parseAxisParam} from '../util';\n\nimport * as axis_util from './axis_util';\nimport {op} from './operation';\nimport {scalar} from './tensor_ops';\n\n/**\n * Computes the norm of scalar, vectors, and matrices.\n * This function can compute several different vector norms (the 1-norm, the\n * Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0)\n * and matrix norms (Frobenius, 1-norm, and inf-norm).\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3, 4]);\n *\n * x.norm().print(); // or tf.norm(x)\n * ```\n *\n * @param x The input array.\n * @param ord Optional. Order of the norm. Supported norm types are\n * following:\n *\n * | ord | norm for matrices | norm for vectors\n * |------------|---------------------------|---------------------\n * |'euclidean' |Frobenius norm |2-norm\n * |'fro' |Frobenius norm\t |\n * |Infinity |max(sum(abs(x), axis=1)) |max(abs(x))\n * |-Infinity |min(sum(abs(x), axis=1)) |min(abs(x))\n * |1 |max(sum(abs(x), axis=0)) |sum(abs(x))\n * |2 | |sum(abs(x)^2)^1/2*\n *\n * @param axis Optional. If axis is null (the default), the input is\n * considered a vector and a single vector norm is computed over the entire\n * set of values in the Tensor, i.e. norm(x, ord) is equivalent\n * to norm(x.reshape([-1]), ord). If axis is a integer, the input\n * is considered a batch of vectors, and axis determines the axis in x\n * over which to compute vector norms. If axis is a 2-tuple of integer it is\n * considered a batch of matrices and axis determines the axes in NDArray\n * over which to compute a matrix norm.\n * @param keepDims Optional. If true, the norm have the same dimensionality\n * as the input.\n */\n/** @doc {heading: 'Operations', subheading: 'Matrices'} */\nfunction norm_(\n x: Tensor|TensorLike, ord: number|'euclidean'|'fro' = 'euclidean',\n axis: number|number[] = null, keepDims = false): Tensor {\n x = convertToTensor(x, 'x', 'norm');\n\n const norm = normImpl(x, ord, axis);\n let keepDimsShape = norm.shape;\n if (keepDims) {\n const axes = parseAxisParam(axis, x.shape);\n keepDimsShape = axis_util.expandShapeToKeepDim(norm.shape, axes);\n }\n return norm.reshape(keepDimsShape);\n}\n\nfunction normImpl(\n x: Tensor, p: number|string, axis: number|number[] = null): Tensor {\n if (x.rank === 0) {\n return x.abs();\n }\n\n // consider vector when no axis is specified\n if (x.rank !== 1 && axis === null) {\n return normImpl(x.reshape([-1]), p, axis);\n }\n\n // vector\n if (x.rank === 1 || typeof axis === 'number' ||\n Array.isArray(axis) && axis.length === 1) {\n if (p === 1) {\n return x.abs().sum(axis);\n }\n if (p === Infinity) {\n return x.abs().max(axis);\n }\n if (p === -Infinity) {\n return x.abs().min(axis);\n }\n if (p === 'euclidean' || p === 2) {\n // norm(x, 2) = sum(abs(xi) ^ 2) ^ 1/2\n return x.abs().pow(scalar(2, 'int32')).sum(axis).sqrt();\n }\n\n throw new Error(`Error in norm: invalid ord value: ${p}`);\n }\n\n // matrix (assumption axis[0] < axis[1])\n if (Array.isArray(axis) && axis.length === 2) {\n if (p === 1) {\n return x.abs().sum(axis[0]).max(axis[1] - 1);\n }\n if (p === Infinity) {\n return x.abs().sum(axis[1]).max(axis[0]);\n }\n if (p === -Infinity) {\n return x.abs().sum(axis[1]).min(axis[0]);\n }\n if (p === 'fro' || p === 'euclidean') {\n // norm(x) = sqrt(sum(pow(x, 2)))\n return x.square().sum(axis).sqrt();\n }\n\n throw new Error(`Error in norm: invalid ord value: ${p}`);\n }\n\n throw new Error(`Error in norm: invalid axis: ${axis}`);\n}\n\nexport const norm = op({norm_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Scalar, Tensor1D, Tensor2D} from '../tensor';\nimport {convertToTensor, convertToTensorArray} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {op} from './operation';\n\n/**\n * @docalias (data: Tensor2D, c: Tensor2D, h: Tensor2D): [Tensor2D, Tensor2D]\n */\nexport type LSTMCellFunc = {\n (data: Tensor2D, c: Tensor2D, h: Tensor2D): [Tensor2D, Tensor2D];\n};\n\n/**\n * Computes the next states and outputs of a stack of LSTMCells.\n *\n * Each cell output is used as input to the next cell.\n *\n * Returns `[cellState, cellOutput]`.\n *\n * Derived from tf.contrib.rn.MultiRNNCell.\n *\n * @param lstmCells Array of LSTMCell functions.\n * @param data The input to the cell.\n * @param c Array of previous cell states.\n * @param h Array of previous cell outputs.\n */\n/** @doc {heading: 'Operations', subheading: 'RNN'} */\nfunction multiRNNCell_(\n lstmCells: LSTMCellFunc[], data: Tensor2D|TensorLike,\n c: Array,\n h: Array): [Tensor2D[], Tensor2D[]] {\n const $data = convertToTensor(data, 'data', 'multiRNNCell');\n const $c = convertToTensorArray(c, 'c', 'multiRNNCell');\n const $h = convertToTensorArray(h, 'h', 'multiRNNCell');\n\n let input = $data;\n const newStates = [];\n for (let i = 0; i < lstmCells.length; i++) {\n const output = lstmCells[i](input, $c[i], $h[i]);\n newStates.push(output[0]);\n newStates.push(output[1]);\n input = output[1];\n }\n const newC: Tensor2D[] = [];\n const newH: Tensor2D[] = [];\n for (let i = 0; i < newStates.length; i += 2) {\n newC.push(newStates[i]);\n newH.push(newStates[i + 1]);\n }\n return [newC, newH];\n}\n\n/**\n * Computes the next state and output of a BasicLSTMCell.\n *\n * Returns `[newC, newH]`.\n *\n * Derived from tf.contrib.rnn.BasicLSTMCell.\n *\n * @param forgetBias Forget bias for the cell.\n * @param lstmKernel The weights for the cell.\n * @param lstmBias The bias for the cell.\n * @param data The input to the cell.\n * @param c Previous cell state.\n * @param h Previous cell output.\n */\n/** @doc {heading: 'Operations', subheading: 'RNN'} */\nfunction basicLSTMCell_(\n forgetBias: Scalar|TensorLike, lstmKernel: Tensor2D|TensorLike,\n lstmBias: Tensor1D|TensorLike, data: Tensor2D|TensorLike,\n c: Tensor2D|TensorLike, h: Tensor2D|TensorLike): [Tensor2D, Tensor2D] {\n const $forgetBias =\n convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell');\n const $lstmKernel =\n convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell');\n const $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell');\n const $data = convertToTensor(data, 'data', 'basicLSTMCell');\n const $c = convertToTensor(c, 'c', 'basicLSTMCell');\n const $h = convertToTensor(h, 'h', 'basicLSTMCell');\n\n const combined = $data.concat($h, 1);\n const weighted = combined.matMul($lstmKernel);\n const res: Tensor2D = weighted.add($lstmBias);\n\n // i = input_gate, j = new_input, f = forget_gate, o = output_gate\n const batchSize = res.shape[0];\n const sliceCols = res.shape[1] / 4;\n const sliceSize: [number, number] = [batchSize, sliceCols];\n const i = res.slice([0, 0], sliceSize);\n const j = res.slice([0, sliceCols], sliceSize);\n const f = res.slice([0, sliceCols * 2], sliceSize);\n const o = res.slice([0, sliceCols * 3], sliceSize);\n\n const newC = i.sigmoid().mulStrict(j.tanh()).addStrict(\n $c.mulStrict($forgetBias.add(f).sigmoid() as Tensor2D));\n const newH = newC.tanh().mulStrict(o.sigmoid());\n return [newC, newH];\n}\n\nexport const basicLSTMCell = op({basicLSTMCell_});\nexport const multiRNNCell = op({multiRNNCell_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Scalar, Tensor} from '../tensor';\nimport {assertTypesMatch} from '../tensor_util';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\nimport {pow} from './binary_ops';\nimport {op} from './operation';\nimport {scalar} from './tensor_ops';\n\n/**\n * Compute the moving average of a variable.\n *\n * Without zeroDebias, the moving average operation is defined by:\n * `v += delta`\n * where\n * `delta = (1 - decay) * (x - v)`\n *\n * With zeroDebias (default), the `delta` term is scaled to debias the\n * effect of the (assumed) zero-initialization of `v`.\n * `delta /= (1 - decay ^ step)`\n *\n * For more details on the zero-debiasing algorithm, see:\n * https://arxiv.org/abs/1412.6980\n *\n * Note that this function is completely stateless and does not keep track of\n * step count. The step count needs to be maintained by the caller and passed\n * in as `step`.\n *\n * @param v The current moving average value.\n * @param x New input value, must have the same shape and dtype as `v`.\n * @param decay The decay factor. Typical values are 0.95 and 0.99.\n * @param step Step count.\n * @param zeroDebias: Whether zeroDebias is to be performed (default: `true`).\n * @returns The new moving average value.\n */\n/** @doc {heading: 'Operations', subheading: 'Moving Average'} */\nfunction movingAverage_(\n v: T|TensorLike, x: T|TensorLike, decay: number|Scalar,\n step?: number|Scalar, zeroDebias = true): T {\n const $v = convertToTensor(v, 'v', 'movingAverage');\n const $x = convertToTensor(x, 'x', 'movingAverage');\n const $decay = convertToTensor(decay, 'decay', 'movingAverage');\n\n assertTypesMatch($v, $x);\n util.assert(\n util.arraysEqual($v.shape, $x.shape), () => 'Shape mismatch in v and x');\n\n const one = scalar(1);\n const oneMinusDecay = one.sub($decay);\n\n let update = $x.sub($v).mul(oneMinusDecay);\n if (zeroDebias) {\n util.assert(\n step != null, () => 'When using zeroDebias: true, step is required.');\n const $step = convertToTensor(step, 'step', 'movingAverage');\n update = update.div(one.sub(pow($decay, $step)));\n }\n return $v.add(update);\n}\n\nexport const movingAverage = op({movingAverage_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\n\nimport {op} from './operation';\nimport {slice} from './slice';\nimport {computeOutShape, maskToAxes, startForAxis, stopForAxis} from './slice_util';\n\n/**\n * Extracts a strided slice of a tensor.\n *\n * Roughly speaking, this op extracts a slice of size (end-begin)/stride from\n * the given input tensor (x). Starting at the location specified by begin the\n * slice continues by adding stride to the index until all dimensions are not\n * less than end. Note that a stride can be negative, which causes a reverse\n * slice.\n *\n * ```js\n * const t = tf.tensor3d([1, 1, 1 ,2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],\n * [3, 2, 3]);\n * t.stridedSlice([1, 0, 0], [2, 1, 3], [1, 1, 1]).print() // [[[3, 3, 3]]]\n * t.stridedSlice([1, 0, 0], [2, 2, 3], [1, 1, 1]).print() // [[[3, 3, 3],\n * // [4, 4, 4]]]\n * t.stridedSlice([1, -1, 0], [2, -3, 3], [1, -1, 1]).print() // [[[4, 4, 4],\n * // [3, 3, 3]]]\n * ```\n *\n * @param x The tensor to stride slice.\n * @param begin The coordinates to start the slice from.\n * @param end: The coordinates to end the slice at.\n * @param strides: The size of the slice.\n * @param beginMask: If the ith bit of beginMask is set, begin[i] is ignored\n * and the fullest possible range in that dimension is used instead.\n * @param endMask: If the ith bit of endMask is set, end[i] is ignored\n * and the fullest possible range in that dimension is used instead.\n * @param shrinkAxisMask: a bitmask where bit i implies that\n * the ith specification should shrink the dimensionality. begin and end must\n * imply a slice of size 1 in the dimension.\n */\n/** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */\nfunction stridedSlice_(\n x: Tensor|TensorLike, begin: number[], end: number[], strides?: number[],\n beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0,\n shrinkAxisMask = 0): Tensor {\n if (strides == null) {\n strides = new Array(begin.length);\n }\n if (ellipsisMask !== 0) {\n throw new Error('ellipsis mask is not yet supported');\n }\n let $x = convertToTensor(x, 'x', 'stridedSlice');\n\n // Expand the dims of x based on the newAxisMask.\n const expandAxes = maskToAxes(newAxisMask);\n const newShape = $x.shape.slice();\n expandAxes.forEach(axis => {\n begin[axis] = 0;\n end[axis] = 1;\n newShape.splice(axis, 0, 1);\n });\n $x = $x.reshape(newShape);\n\n // Normalize the start, end and strides.\n for (let axis = 0; axis < $x.rank; axis++) {\n begin[axis] = startForAxis(beginMask, begin, strides, $x.shape, axis);\n end[axis] = stopForAxis(endMask, end, strides, $x.shape, axis);\n strides[axis] = strides[axis] || 1;\n }\n\n const shrinkAxes = maskToAxes(shrinkAxisMask);\n // Adjust the ends based on the shrink mask.\n shrinkAxes.forEach(axis => {\n end[axis] = begin[axis] + 1;\n strides[axis] = 1;\n });\n\n // Figure out the output shape.\n const size = computeOutShape(begin, end, strides);\n // Remove the axes based on shrinkMask.\n const outShape = size.filter((_, axis) => shrinkAxes.indexOf(axis) === -1);\n\n const nonStrided = strides.every(v => v === 1);\n if (nonStrided) {\n return slice($x, begin, size).reshape(outShape);\n }\n const res = ENGINE.runKernelFunc(\n backend => backend.stridedSlice($x, begin, end, strides), {$x});\n return res.reshape(outShape);\n}\n\nexport const stridedSlice = op({stridedSlice_});\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {NumericTensor, Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {op} from './operation';\n\n/**\n * Finds the values and indices of the `k` largest entries along the last\n * dimension.\n *\n * If the input is a vector (rank=1), finds the k largest entries in the vector\n * and outputs their values and indices as vectors. Thus values[j] is the j-th\n * largest entry in input, and its index is indices[j].\n * For higher rank inputs, computes the top k entries along the last dimension.\n *\n * If two elements are equal, the lower-index element appears first.\n *\n * ```js\n * const a = tf.tensor2d([[1, 5], [4, 3]]);\n * const {values, indices} = tf.topk(a);\n * values.print();\n * indices.print();\n * ```\n * @param x 1-D or higher `tf.Tensor` with last dimension being at least `k`.\n * @param k Number of top elements to look for along the last dimension.\n * @param sorted If true, the resulting `k` elements will be sorted by the\n * values in descending order.\n */\n/** @doc {heading: 'Operations', subheading: 'Evaluation'} */\nfunction topk_(\n x: T|TensorLike, k = 1, sorted = true): {values: T, indices: T} {\n const $x = convertToTensor(x, 'x', 'topk');\n if ($x.rank === 0) {\n throw new Error('topk() expects the input to be of rank 1 or higher');\n }\n const lastDim = $x.shape[$x.shape.length - 1];\n if (k > lastDim) {\n throw new Error(\n `'k' passed to topk() must be <= the last dimension (${lastDim}) ` +\n `but got ${k}`);\n }\n\n const [values, indices] =\n ENGINE.runKernelFunc(b => b.topk($x as NumericTensor, k, sorted), {$x});\n return {values, indices} as {values: T, indices: T};\n}\n\nexport const topk = op({topk_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {Rank, ShapeMap, TensorLike} from '../types';\nimport {op} from './operation';\nimport * as scatter_nd_util from './scatter_nd_util';\n\n/**\n * Creates a new tensor by applying sparse updates to individual\n * values or slices within a zero tensor of the given shape tensor according to\n * indices. This operator is the inverse of the `tf.gatherND` operator which\n * extracts values or slices from a given tensor.\n *\n * ```js\n * const indices = tf.tensor2d([4, 3, 1, 7], [4, 1], 'int32');\n * const updates = tf.tensor1d([9, 10, 11, 12]);\n * const shape = [8];\n * tf.scatterND(indices, updates, shape).print() //[0, 11, 0, 10, 9, 0, 0, 12]\n * ```\n *\n * @param indices The tensor contains the indices into the output tensor.\n * @param updates The tensor contains the value for the indices.\n * @param shape: The shape of the output tensor.\n */\n/** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */\nfunction scatterND_(\n indices: Tensor|TensorLike, updates: Tensor|TensorLike,\n shape: ShapeMap[R]): Tensor {\n const $indices = convertToTensor(indices, 'indices', 'scatterND', 'int32');\n const $updates = convertToTensor(updates, 'updates', 'scatterND');\n scatter_nd_util.validateInput($updates, $indices, shape);\n\n return ENGINE.runKernelFunc(\n backend => backend.scatterND($indices, $updates, shape),\n {indices: $indices, updates: $updates}, null /* backward */, 'ScatterNd',\n {shape});\n}\n\nexport const scatterND = op({scatterND_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {complex, imag, real} from '../ops/complex_ops';\nimport {op} from '../ops/operation';\nimport {Tensor, Tensor2D} from '../tensor';\nimport {assert} from '../util';\nimport {scalar, zeros} from './tensor_ops';\n\n/**\n * Fast Fourier transform.\n *\n * Computes the 1-dimensional discrete Fourier transform over the inner-most\n * dimension of input.\n *\n * ```js\n * const real = tf.tensor1d([1, 2, 3]);\n * const imag = tf.tensor1d([1, 2, 3]);\n * const x = tf.complex(real, imag);\n *\n * x.fft().print(); // tf.spectral.fft(x).print();\n * ```\n * @param input The complex input to compute an fft over.\n */\n/**\n * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}\n */\nfunction fft_(input: Tensor): Tensor {\n assert(\n input.dtype === 'complex64',\n () => `The dtype for tf.spectral.fft() must be complex64 ` +\n `but got ${input.dtype}.`);\n\n // Collapse all outer dimensions to a single batch dimension.\n const innerDimensionSize = input.shape[input.shape.length - 1];\n const batch = input.size / innerDimensionSize;\n const input2D = input.as2D(batch, innerDimensionSize);\n\n const ret = ENGINE.runKernelFunc(backend => backend.fft(input2D), {input});\n\n return ret.reshape(input.shape);\n}\n\n/**\n * Inverse fast Fourier transform.\n *\n * Computes the inverse 1-dimensional discrete Fourier transform over the\n * inner-most dimension of input.\n *\n * ```js\n * const real = tf.tensor1d([1, 2, 3]);\n * const imag = tf.tensor1d([1, 2, 3]);\n * const x = tf.complex(real, imag);\n *\n * x.ifft().print(); // tf.spectral.ifft(x).print();\n * ```\n * @param input The complex input to compute an ifft over.\n */\n/**\n * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}\n */\nfunction ifft_(input: Tensor): Tensor {\n assert(\n input.dtype === 'complex64',\n () => `The dtype for tf.spectral.ifft() must be complex64 ` +\n `but got ${input.dtype}.`);\n\n // Collapse all outer dimensions to a single batch dimension.\n const innerDimensionSize = input.shape[input.shape.length - 1];\n const batch = input.size / innerDimensionSize;\n const input2D = input.as2D(batch, innerDimensionSize);\n\n const ret = ENGINE.runKernelFunc(backend => backend.ifft(input2D), {input});\n\n return ret.reshape(input.shape);\n}\n\n/**\n * Real value input fast Fourier transform.\n *\n * Computes the 1-dimensional discrete Fourier transform over the\n * inner-most dimension of the real input.\n *\n * ```js\n * const real = tf.tensor1d([1, 2, 3]);\n *\n * real.rfft().print();\n * ```\n * @param input The real value input to compute an rfft over.\n */\n/**\n * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}\n */\nfunction rfft_(input: Tensor, fftLength?: number): Tensor {\n assert(\n input.dtype === 'float32',\n () => `The dtype for rfft() must be real value but got ${input.dtype}`);\n\n let innerDimensionSize = input.shape[input.shape.length - 1];\n const batch = input.size / innerDimensionSize;\n\n let adjustedInput: Tensor;\n if (fftLength != null && fftLength < innerDimensionSize) {\n // Need to crop\n const begin = input.shape.map(v => 0);\n const size = input.shape.map(v => v);\n size[input.shape.length - 1] = fftLength;\n adjustedInput = input.slice(begin, size);\n innerDimensionSize = fftLength;\n } else if (fftLength != null && fftLength > innerDimensionSize) {\n // Need to pad with zeros\n const zerosShape = input.shape.map(v => v);\n zerosShape[input.shape.length - 1] = fftLength - innerDimensionSize;\n adjustedInput = input.concat(zeros(zerosShape), input.shape.length - 1);\n innerDimensionSize = fftLength;\n } else {\n adjustedInput = input;\n }\n\n // Complement the input with zero imaginary numbers.\n const zerosInput = adjustedInput.zerosLike();\n const complexInput =\n complex(adjustedInput, zerosInput).as2D(batch, innerDimensionSize);\n\n const ret = fft(complexInput);\n\n // Exclude complex conjugations. These conjugations are put symmetrically.\n const half = Math.floor(innerDimensionSize / 2) + 1;\n const realValues = real(ret);\n const imagValues = imag(ret);\n const realComplexConjugate = realValues.split(\n [half, innerDimensionSize - half], realValues.shape.length - 1);\n const imagComplexConjugate = imagValues.split(\n [half, innerDimensionSize - half], imagValues.shape.length - 1);\n\n const outputShape = adjustedInput.shape.slice();\n outputShape[adjustedInput.shape.length - 1] = half;\n\n return complex(realComplexConjugate[0], imagComplexConjugate[0])\n .reshape(outputShape);\n}\n\n/**\n * Inversed real value input fast Fourier transform.\n *\n * Computes the 1-dimensional inversed discrete Fourier transform over the\n * inner-most dimension of the real input.\n *\n * ```js\n * const real = tf.tensor1d([1, 2, 3]);\n * const imag = tf.tensor1d([0, 0, 0]);\n * const x = tf.complex(real, imag);\n *\n * x.irfft().print();\n * ```\n * @param input The real value input to compute an irfft over.\n */\n/**\n * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}\n */\nfunction irfft_(input: Tensor): Tensor {\n const innerDimensionSize = input.shape[input.shape.length - 1];\n const batch = input.size / innerDimensionSize;\n\n if (innerDimensionSize <= 2) {\n const complexInput = input.as2D(batch, innerDimensionSize);\n const ret = ifft(complexInput);\n return real(ret);\n } else {\n // The length of unique components of the DFT of a real-valued signal\n // is 2 * (input_len - 1)\n const outputShape = [batch, 2 * (innerDimensionSize - 1)];\n const realInput = real(input).as2D(batch, innerDimensionSize);\n const imagInput = imag(input).as2D(batch, innerDimensionSize);\n\n const realConjugate =\n realInput.slice([0, 1], [batch, innerDimensionSize - 2]).reverse(1);\n const imagConjugate: Tensor2D =\n imagInput.slice([0, 1], [batch, innerDimensionSize - 2])\n .reverse(1)\n .mul(scalar(-1));\n\n const r = realInput.concat(realConjugate, 1);\n const i = imagInput.concat(imagConjugate, 1);\n const complexInput = complex(r, i).as2D(outputShape[0], outputShape[1]);\n const ret = ifft(complexInput);\n return real(ret);\n }\n}\n\nexport const fft = op({fft_});\nexport const ifft = op({ifft_});\nexport const rfft = op({rfft_});\nexport const irfft = op({irfft_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport * as sparse_to_dense from '../ops/sparse_to_dense_util';\nimport {Scalar, Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {Rank, ScalarLike, ShapeMap, TensorLike} from '../types';\nimport {op} from './operation';\n\n/**\n * Converts a sparse representation into a dense tensor.\n *\n * Builds an array dense with shape outputShape such that:\n *\n * // If sparseIndices is scalar\n * dense[i] = (i == sparseIndices ? sparseValues : defaultValue)\n *\n * // If sparseIndices is a vector, then for each i\n * dense[sparseIndices[i]] = sparseValues[i]\n *\n * // If sparseIndices is an n by d matrix, then for each i in [0, n)\n * dense[sparseIndices[i][0], ..., sparseIndices[i][d-1]] = sparseValues[i]\n * All other values in dense are set to defaultValue. If sparseValues is a\n * scalar, all sparse indices are set to this single value.\n *\n * If indices are repeated the final value is summed over all values for those\n * indices.\n *\n * ```js\n * const indices = tf.tensor1d([4, 5, 6, 1, 2, 3], 'int32');\n * const values = tf.tensor1d([10, 11, 12, 13, 14, 15], 'float32');\n * const shape = [8];\n * tf.sparseToDense(indices, values, shape).print();\n * ```\n *\n * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.\n * sparseIndices[i] contains the complete index where sparseValues[i] will be\n * placed.\n * @param sparseValues A 0-D or 1-D Tensor. Values\n * corresponding to each row of sparseIndices, or a scalar value to be used for\n * all sparse indices.\n * @param outputShape Shape of the dense output tensor. the type is inferred.\n * @param defaultValue Scalar. Value to set for indices not specified in\n * sparseIndices. Defaults to zero.\n */\n/** @doc {heading: 'Operations', subheading: 'Normalization'} */\nfunction sparseToDense_(\n sparseIndices: Tensor|TensorLike, sparseValues: Tensor|TensorLike,\n outputShape: ShapeMap[R], defaultValue: Scalar|ScalarLike = 0): Tensor {\n const $sparseIndices =\n convertToTensor(sparseIndices, 'sparseIndices', 'sparseToDense', 'int32');\n const $sparseValues =\n convertToTensor(sparseValues, 'sparseValues', 'sparseToDense');\n const $defaultValue = convertToTensor(\n defaultValue, 'defaultValue', 'sparseToDense', $sparseValues.dtype);\n\n sparse_to_dense.validateInput(\n $sparseIndices, $sparseValues, outputShape, $defaultValue);\n\n return ENGINE.runKernelFunc(\n backend => backend.sparseToDense(\n $sparseIndices, $sparseValues, outputShape, $defaultValue),\n {$sparseIndices, $sparseValues, $defaultValue});\n}\n\nexport const sparseToDense = op({sparseToDense_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {Tensor} from '../tensor';\n\n/**\n * Validate sparseToDense inputs.\n *\n * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.\n * sparseIndices[i] contains the complete index where sparseValues[i] will be\n * placed.\n * @param sparseValues A 0-D or 1-D Tensor. Values\n * corresponding to each row of sparseIndices, or a scalar value to be used for\n * all sparse indices.\n * @param outputShape number[]. Shape of the dense output tensor.\n * @param validateIndices boolean. indice validation is not supported, error\n * will be thrown if it is set.\n */\nexport function validateInput(\n sparseIndices: Tensor, sparseValues: Tensor, outputShape: number[],\n defaultValues: Tensor) {\n if (sparseIndices.dtype !== 'int32') {\n throw new Error(\n 'tf.sparseToDense() expects the indices to be int32 type,' +\n ` but the dtype was ${sparseIndices.dtype}.`);\n }\n if (sparseIndices.rank > 2) {\n throw new Error(\n 'sparseIndices should be a scalar, vector, or matrix,' +\n ` but got shape ${sparseIndices.shape}.`);\n }\n\n const numElems = sparseIndices.rank > 0 ? sparseIndices.shape[0] : 1;\n const numDims = sparseIndices.rank > 1 ? sparseIndices.shape[1] : 1;\n\n if (outputShape.length !== numDims) {\n throw new Error(\n 'outputShape has incorrect number of elements:,' +\n ` ${outputShape.length}, should be: ${numDims}.`);\n }\n\n const numValues = sparseValues.size;\n if (!(sparseValues.rank === 0 ||\n sparseValues.rank === 1 && numValues === numElems)) {\n throw new Error(\n 'sparseValues has incorrect shape ' +\n `${sparseValues.shape}, should be [] or [${numElems}]`);\n }\n\n if (sparseValues.dtype !== defaultValues.dtype) {\n throw new Error('sparseValues.dtype must match defaultValues.dtype');\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {ENGINE} from '../engine';\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {op} from './operation';\n\n/**\n * Gather slices from input tensor into a Tensor with shape specified by\n * `indices`.\n *\n * `indices` is an K-dimensional integer tensor, best thought of as a\n * (K-1)-dimensional tensor of indices into input, where each element defines a\n * slice of input:\n * output[\\\\(i_0, ..., i_{K-2}\\\\)] = input[indices[\\\\(i_0, ..., i_{K-2}\\\\)]]\n *\n * Whereas in `tf.gather`, `indices` defines slices into the first dimension of\n * input, in `tf.gatherND`, `indices` defines slices into the first N dimensions\n * of input, where N = indices.shape[-1].\n *\n * The last dimension of indices can be at most the rank of input:\n * indices.shape[-1] <= input.rank\n *\n * The last dimension of `indices` corresponds to elements\n * (if indices.shape[-1] == input.rank) or slices\n * (if indices.shape[-1] < input.rank) along dimension indices.shape[-1] of\n * input.\n * The output tensor has shape\n * indices.shape[:-1] + input.shape[indices.shape[-1]:]\n *\n * Note that on CPU, if an out of bound index is found, an error is returned. On\n * GPU, if an out of bound index is found, a 0 is stored in the corresponding\n * output value.\n *\n * ```js\n * const indices = tf.tensor2d([0, 1, 1, 0], [2,2], 'int32');\n * const input = tf.tensor2d([9, 10, 11, 12], [2, 2]);\n * tf.gatherND(input, indices).print() // [10, 11]\n * ```\n *\n * @param x The tensor from which to gather values.\n * @param indices Index tensor, must be of type int32.\n */\n/** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */\nfunction gatherND_(x: Tensor|TensorLike, indices: Tensor|TensorLike): Tensor {\n const $indices = convertToTensor(indices, 'indices', 'gatherND', 'int32');\n const $x = convertToTensor(x, 'x', 'gatherND');\n return ENGINE.runKernelFunc(\n backend => backend.gatherND($x, $indices), {x: $x, indices: $indices},\n null /* backward */, 'GatherNd');\n}\nexport const gatherND = op({gatherND_});\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {op} from './operation';\n\n/**\n * Returns a diagonal tensor with a given diagonal values.\n *\n * Given a diagonal, this operation returns a tensor with the diagonal and\n * everything else padded with zeros.\n *\n * Assume the input has dimensions `[D1,..., Dk]`, then the output is a tensor\n * of rank 2k with dimensions `[D1,..., Dk, D1,..., Dk]`\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 3, 4]);\n *\n * tf.diag(x).print()\n * ```\n * ```js\n * const x = tf.tensor1d([1, 2, 3, 4, 5, 6, 6, 8], [4, 2])\n *\n * tf.diag(x).print()\n * ```\n * @param x The input tensor.\n */\nfunction diag_(x: Tensor): Tensor {\n const $x = convertToTensor(x, 'x', 'diag').flatten();\n const outShape = [...x.shape, ...x.shape];\n return ENGINE.runKernelFunc(backend => backend.diag($x), {$x})\n .reshape(outShape);\n}\n\nexport const diag = op({diag_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {getNoiseShape} from './dropout_util';\nimport {op} from './operation';\nimport {randomUniform} from './random_uniform';\n\n/**\n * Computes dropout.\n *\n * ```js\n * const x = tf.tensor1d([1, 2, 2, 1]);\n * const rate = 0.75;\n * const output = tf.dropout(x, rate);\n * output.print();\n * ```\n *\n * @param x A floating point Tensor or TensorLike.\n * @param rate A float in the range [0, 1). The probability that each element\n * of x is discarded.\n * @param noiseShape An array of numbers of type int32, representing the\n * shape for randomly generated keep/drop flags. If the noiseShape has null\n * value, it will be automatically replaced with the x's relative dimension\n * size. Optional.\n * @param seed Used to create random seeds. Optional.\n * @returns A Tensor of the same shape of x.\n */\n/** @doc {heading: 'Operations', subheading: 'Dropout'} */\nfunction dropout_(\n x: Tensor|TensorLike, rate: number, noiseShape?: number[],\n seed?: number|string): Tensor {\n const $x = convertToTensor(x, 'x', 'dropout');\n\n util.assert(\n $x.dtype === 'float32',\n () => `x has to be a floating point tensor since it's going to be ` +\n `scaled, but got a ${$x.dtype} tensor instead.`);\n util.assert(\n rate >= 0 && rate < 1,\n () => `rate must be a float in the range [0, 1), but got ${rate}.`);\n\n if (rate === 0) {\n return x instanceof Tensor ? $x.clone() : $x;\n }\n\n const $noiseShape = getNoiseShape($x, noiseShape);\n const keepProb = 1 - rate;\n const multiplier = randomUniform($noiseShape, 0, 1, 'float32', seed)\n .add(keepProb)\n .floor()\n .div(keepProb);\n\n return $x.mul(multiplier);\n}\n\nexport const dropout = op({dropout_});\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\nimport * as util from '../util';\n\n/**\n * Normalize noise shape based on provided tensor and noise shape.\n *\n * @param x Tensor.\n * @param noiseShape The shape for the randomly generated keep/drop flags, as\n * an array of numbers. Optional.\n * @returns Normalized noise shape.\n */\nexport function getNoiseShape(x: Tensor, noiseShape?: number[]): number[] {\n if (noiseShape == null) {\n return x.shape.slice();\n }\n if (util.arraysEqual(x.shape, noiseShape)) {\n return noiseShape;\n }\n if (x.shape.length === noiseShape.length) {\n const newDimension: number[] = [];\n for (let i = 0; i < x.shape.length; i++) {\n if (noiseShape[i] == null && x.shape[i] != null) {\n newDimension.push(x.shape[i]);\n } else {\n newDimension.push(noiseShape[i]);\n }\n }\n return newDimension;\n }\n\n return noiseShape;\n}\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {op} from '../ops/operation';\nimport {Tensor, Tensor1D} from '../tensor';\n\nimport {mul} from './binary_ops';\nimport {concat} from './concat_split';\nimport {slice} from './slice';\nimport {rfft} from './spectral_ops';\nimport {fill, tensor1d, tensor2d} from './tensor_ops';\n\n/**\n * Generate a Hann window.\n *\n * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows\n *\n * ```js\n * tf.signal.hannWindow(10).print();\n * ```\n * @param The length of window\n */\n/**\n * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}\n */\nfunction hannWindow_(windowLength: number): Tensor1D {\n return cosineWindow(windowLength, 0.5, 0.5);\n}\n\n/**\n * Generate a hamming window.\n *\n * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows\n *\n * ```js\n * tf.signal.hammingWindow(10).print();\n * ```\n * @param The length of window\n */\n/**\n * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}\n */\nfunction hammingWindow_(windowLength: number): Tensor1D {\n return cosineWindow(windowLength, 0.54, 0.46);\n}\n\n/**\n * Expands input into frames of frameLength.\n * Slides a window size with frameStep.\n *\n * ```js\n * tf.signal.frame([1, 2, 3], 2, 1).print();\n * ```\n * @param signal The input tensor to be expanded\n * @param frameLength Length of each frame\n * @param frameStep The frame hop size in samples.\n * @param padEnd Whether to pad the end of signal with padValue.\n * @param padValue An number to use where the input signal does\n * not exist when padEnd is True.\n */\n/**\n * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}\n */\nfunction frame_(\n signal: Tensor1D, frameLength: number, frameStep: number, padEnd = false,\n padValue = 0): Tensor {\n let start = 0;\n const output: Tensor[] = [];\n while (start + frameLength <= signal.size) {\n output.push(slice(signal, start, frameLength));\n start += frameStep;\n }\n\n if (padEnd) {\n while (start < signal.size) {\n const padLen = (start + frameLength) - signal.size;\n const pad = concat(\n [slice(signal, start, frameLength - padLen),\n fill([padLen], padValue)]);\n output.push(pad);\n start += frameStep;\n }\n }\n\n if (output.length === 0) {\n return tensor2d([], [0, frameLength]);\n }\n\n return concat(output).as2D(output.length, frameLength);\n}\n\n/**\n * Computes the Short-time Fourier Transform of signals\n * See: https://en.wikipedia.org/wiki/Short-time_Fourier_transform\n *\n * ```js\n * const input = tf.tensor1d([1, 1, 1, 1, 1])\n * tf.signal.stft(input, 3, 1).print();\n * ```\n * @param signal 1-dimensional real value tensor.\n * @param frameLength The window length of samples.\n * @param frameStep The number of samples to step.\n * @param fftLength The size of the FFT to apply.\n * @param windowFn A callable that takes a window length and returns 1-d tensor.\n */\n/**\n * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}\n */\nfunction stft_(\n signal: Tensor1D, frameLength: number, frameStep: number,\n fftLength?: number,\n windowFn: (length: number) => Tensor1D = hannWindow): Tensor {\n if (fftLength == null) {\n fftLength = enclosingPowerOfTwo(frameLength);\n }\n const framedSignal = frame(signal, frameLength, frameStep);\n const windowedSignal = mul(framedSignal, windowFn(frameLength));\n const output: Tensor[] = [];\n for (let i = 0; i < framedSignal.shape[0]; i++) {\n output.push(rfft(windowedSignal.slice([i, 0], [1, frameLength]),\n fftLength));\n }\n return concat(output);\n}\n\nfunction enclosingPowerOfTwo(value: number) {\n // Return 2**N for integer N such that 2**N >= value.\n return Math.floor(Math.pow(2, Math.ceil(Math.log(value) / Math.log(2.0))));\n}\n\nfunction cosineWindow(windowLength: number, a: number, b: number): Tensor1D {\n const even = 1 - windowLength % 2;\n const newValues = new Float32Array(windowLength);\n for (let i = 0; i < windowLength; ++i) {\n const cosArg = (2.0 * Math.PI * i) / (windowLength + even - 1);\n newValues[i] = a - b * Math.cos(cosArg);\n }\n return tensor1d(newValues, 'float32');\n}\n\nexport const hannWindow = op({hannWindow_});\nexport const hammingWindow = op({hammingWindow_});\nexport const frame = op({frame_});\nexport const stft = op({stft_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {customGrad} from '../gradients';\nimport {Tensor} from '../tensor';\nimport {GradSaveFunc} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {assertShapesMatch} from '../util';\nimport {expandShapeToKeepDim} from './axis_util';\nimport {minimum} from './binary_ops';\nimport {op} from './operation';\nimport {ones, scalar} from './tensor_ops';\n\nexport enum Reduction {\n NONE,\n MEAN,\n SUM,\n SUM_BY_NONZERO_WEIGHTS\n}\n\n/**\n * Computes the weighted loss between two tensors.\n *\n * @param losses Tensor of shape `[batch_size, d1, ... dN]`.\n * @param weights Tensor whose rank is either 0, or the same rank as\n * `losses`, and must be broadcastable to `losses` (i.e., all\n * dimensions must be either `1`, or the same as the corresponding\n * `losses` dimension).\n */\n/** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */\nfunction computeWeightedLoss_(\n losses: T|TensorLike, weights?: Tensor|TensorLike,\n reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {\n const $losses = convertToTensor(losses, 'losses', 'computeWeightedLoss');\n let $weights: Tensor = null;\n if (weights != null) {\n $weights = convertToTensor(weights, 'weights', 'computeWeightedLoss');\n }\n\n const weightedLoss = ($weights == null) ? $losses : $losses.mul($weights);\n\n if (reduction === Reduction.NONE) {\n return weightedLoss as O;\n }\n if (reduction === Reduction.SUM) {\n return weightedLoss.sum();\n }\n if (reduction === Reduction.MEAN) {\n if ($weights == null) {\n return weightedLoss.mean();\n } else {\n const broadcastFactor = $losses.size / $weights.size;\n const result = weightedLoss.sum().div($weights.sum());\n return broadcastFactor > 1 ? result.div(scalar(broadcastFactor)) :\n result as O;\n }\n }\n if (reduction === Reduction.SUM_BY_NONZERO_WEIGHTS) {\n if ($weights == null) {\n return weightedLoss.sum().div(scalar($losses.size));\n } else {\n const broadcastedWeights = $weights.mul(ones($losses.shape));\n\n const numNonZeros =\n broadcastedWeights.notEqual(scalar(0)).sum().toFloat();\n return weightedLoss.sum().div(numNonZeros);\n }\n }\n\n throw Error(`Unknown reduction: ${reduction}`);\n}\n\n/**\n * Computes the absolute difference loss between two tensors.\n *\n * @param labels The ground truth output tensor, same dimensions as\n * 'predictions'.\n * @param predictions The predicted outputs.\n * @param weights Tensor whose rank is either 0, or the same rank as\n * `labels`, and must be broadcastable to `labels` (i.e., all dimensions\n * must be either `1`, or the same as the corresponding `losses`\n * dimension).\n * @param reduction Type of reduction to apply to loss. Should be of type\n * `Reduction`\n */\n/** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */\nfunction absoluteDifference_(\n labels: T|TensorLike, predictions: T|TensorLike,\n weights?: Tensor|TensorLike,\n reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {\n const $labels = convertToTensor(labels, 'labels', 'absoluteDifference');\n const $predictions =\n convertToTensor(predictions, 'predictions', 'absoluteDifference');\n let $weights: Tensor = null;\n if (weights != null) {\n $weights = convertToTensor(weights, 'weights', 'absoluteDifference');\n }\n assertShapesMatch(\n $labels.shape, $predictions.shape, 'Error in absoluteDifference: ');\n\n const losses = $labels.sub($predictions).abs();\n return computeWeightedLoss(losses, $weights, reduction);\n}\n\n/**\n * Computes the mean squared error between two tensors.\n *\n * @param labels The ground truth output tensor, same dimensions as\n * 'predictions'.\n * @param predictions The predicted outputs.\n * @param weights Tensor whose rank is either 0, or the same rank as\n * `labels`, and must be broadcastable to `labels` (i.e., all dimensions\n * must be either `1`, or the same as the corresponding `losses`\n * dimension).\n * @param reduction Type of reduction to apply to loss. Should be of type\n * `Reduction`\n */\n/** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */\nfunction meanSquaredError_(\n labels: T|TensorLike, predictions: T|TensorLike,\n weights?: Tensor|TensorLike,\n reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {\n const $labels = convertToTensor(labels, 'labels', 'meanSquaredError');\n const $predictions =\n convertToTensor(predictions, 'predictions', 'meanSquaredError');\n let $weights: Tensor = null;\n if (weights != null) {\n $weights = convertToTensor(weights, 'weights', 'meanSquaredError');\n }\n assertShapesMatch(\n $labels.shape, $predictions.shape, 'Error in meanSquaredError: ');\n\n const losses = $labels.squaredDifference($predictions);\n return computeWeightedLoss(losses, $weights, reduction);\n}\n\n/**\n * Computes the cosine distance loss between two tensors.\n *\n * @param labels The ground truth output tensor, same dimensions as\n * 'predictions'.\n * @param predictions The predicted outputs.\n * @param axis The dimension along which the cosine distance is computed.\n * @param weights Tensor whose rank is either 0, or the same rank as\n * `labels`, and must be broadcastable to `labels` (i.e., all dimensions\n * must be either `1`, or the same as the corresponding `losses`\n * dimension).\n * @param reduction Type of reduction to apply to loss. Should be of type\n * `Reduction`\n */\n/** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */\nfunction cosineDistance_(\n labels: T|TensorLike, predictions: T|TensorLike, axis: number,\n weights?: Tensor|TensorLike,\n reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {\n const $labels = convertToTensor(labels, 'labels', 'cosineDistance');\n const $predictions =\n convertToTensor(predictions, 'predictions', 'cosineDistance');\n let $weights: Tensor = null;\n if (weights != null) {\n $weights = convertToTensor(weights, 'weights', 'cosineDistance');\n }\n assertShapesMatch(\n $labels.shape, $predictions.shape, 'Error in cosineDistance: ');\n\n const one = scalar(1);\n const losses = one.sub($labels.mul($predictions).sum(axis, true));\n return computeWeightedLoss(losses, $weights, reduction);\n}\n\n/**\n * Computes the Hinge loss between two tensors.\n *\n * @param labels The ground truth output tensor, same dimensions as\n * 'predictions'.\n * @param predictions The predicted outputs.\n * @param weights Tensor whose rank is either 0, or the same rank as\n * `labels`, and must be broadcastable to `labels` (i.e., all dimensions\n * must be either `1`, or the same as the corresponding `losses`\n * dimension).\n * @param reduction Type of reduction to apply to loss. Should be of type\n * `Reduction`\n */\n/** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */\nfunction hingeLoss_(\n labels: T|TensorLike, predictions: T|TensorLike,\n weights?: Tensor|TensorLike,\n reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {\n let $labels = convertToTensor(labels, 'labels', 'hingeLoss');\n const $predictions = convertToTensor(predictions, 'predictions', 'hingeLoss');\n let $weights: Tensor = null;\n if (weights != null) {\n $weights = convertToTensor(weights, 'weights', 'hingeLoss');\n }\n assertShapesMatch($labels.shape, $predictions.shape, 'Error in hingeLoss: ');\n\n const one = scalar(1);\n // Convert binary labels to (-1, 1)\n $labels = scalar(2).mul($labels).sub(one);\n const losses = one.sub($labels.mul($predictions)).relu();\n return computeWeightedLoss(losses, $weights, reduction);\n}\n\n/**\n * Computes the log loss between two tensors.\n *\n * @param labels The ground truth output tensor, same dimensions as\n * 'predictions'.\n * @param predictions The predicted outputs.\n * @param weights Tensor whose rank is either 0, or the same rank as\n * `labels`, and must be broadcastable to `labels` (i.e., all dimensions\n * must be either `1`, or the same as the corresponding `losses`\n * dimension).\n * @param epsilon A small increment to avoid taking log of zero\n * @param reduction Type of reduction to apply to loss. Should be of type\n * `Reduction`\n */\n/** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */\nfunction logLoss_(\n labels: T|TensorLike, predictions: T|TensorLike,\n weights?: Tensor|TensorLike, epsilon = 1e-7,\n reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {\n const $labels = convertToTensor(labels, 'labels', 'logLoss');\n const $predictions = convertToTensor(predictions, 'predictions', 'logLoss');\n let $weights: Tensor = null;\n if (weights != null) {\n $weights = convertToTensor(weights, 'weights', 'logLoss');\n }\n assertShapesMatch($labels.shape, $predictions.shape, 'Error in logLoss: ');\n\n const one = scalar(1);\n const epsilonScalar = scalar(epsilon);\n const losses = $labels.mul($predictions.add(epsilonScalar).log())\n .neg()\n .sub(one.sub($labels).mul(\n one.sub($predictions).add(epsilonScalar).log()));\n return computeWeightedLoss(losses, $weights, reduction);\n}\n\nfunction sigmoidCrossEntropyWithLogits_(\n labels: T|TensorLike, logits: T|TensorLike): O {\n const $labels =\n convertToTensor(labels, 'labels', 'sigmoidCrossEntropyWithLogits');\n const $logits =\n convertToTensor(logits, 'logits', 'sigmoidCrossEntropyWithLogits');\n assertShapesMatch(\n $labels.shape, $logits.shape, 'Error in sigmoidCrossEntropyWithLogits: ');\n\n /**\n * Implementation Details:\n *\n * For brevity, let `x = logits`, `z = labels`. The logistic loss is\n * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))\n * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))\n * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))\n * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))\n * = (1 - z) * x + log(1 + exp(-x))\n * = x - x * z + log(1 + exp(-x))\n *\n * For x < 0, to avoid overflow in exp(-x), we reformulate the above\n * x - x * z + log(1 + exp(-x))\n * = log(exp(x)) - x * z + log(1 + exp(-x))\n * = - x * z + log(1 + exp(x))\n *\n * Hence, to ensure stability and avoid overflow, the implementation uses\n * this equivalent formulation:\n * max(x, 0) - x * z + log(1 + exp(-abs(x)))\n */\n const maxOutput = $logits.relu();\n const outputXTarget = $logits.mul($labels);\n const sigmoidOutput = $logits.abs().neg().exp().log1p();\n\n return maxOutput.sub(outputXTarget).add(sigmoidOutput);\n}\n\n/**\n * Computes the sigmoid cross entropy loss between two tensors.\n *\n * If labelSmoothing is nonzero, smooth the labels towards 1/2:\n *\n * newMulticlassLabels = multiclassLabels * (1 - labelSmoothing)\n * + 0.5 * labelSmoothing\n *\n * @param multiClassLabels The ground truth output tensor of shape\n * [batch_size, num_classes], same dimensions as 'predictions'.\n * @param logits The predicted outputs.\n * @param weights Tensor whose rank is either 0, or the same rank as\n * `labels`, and must be broadcastable to `labels` (i.e., all dimensions\n * must be either `1`, or the same as the corresponding `losses`\n * dimension).\n * @param labelSmoothing If greater than 0, then smooth the labels.\n * @param reduction Type of reduction to apply to loss. Should be of type\n * `Reduction`\n */\n/** @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' } */\nfunction sigmoidCrossEntropy_(\n multiClassLabels: T|TensorLike, logits: T|TensorLike,\n weights?: Tensor|TensorLike, labelSmoothing = 0,\n reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {\n let $multiClassLabels = convertToTensor(\n multiClassLabels, 'multiClassLabels', 'sigmoidCrossEntropy');\n const $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropy');\n let $weights: Tensor = null;\n if (weights != null) {\n $weights = convertToTensor(weights, 'weights', 'sigmoidCrossEntropy');\n }\n assertShapesMatch(\n $multiClassLabels.shape, $logits.shape, 'Error in sigmoidCrossEntropy: ');\n\n if (labelSmoothing > 0) {\n const labelSmoothingScalar = scalar(labelSmoothing);\n const one = scalar(1);\n const half = scalar(0.5);\n\n $multiClassLabels = $multiClassLabels.mul(one.sub(labelSmoothingScalar))\n .add(half.mul(labelSmoothingScalar));\n }\n const losses = sigmoidCrossEntropyWithLogits_($multiClassLabels, $logits);\n\n return computeWeightedLoss(losses, $weights, reduction);\n}\n\n/**\n * Computes the huber loss between two tensors.\n *\n * @param labels The ground truth output tensor, same dimensions as\n * 'predictions'.\n * @param predictions The predicted outputs.\n * @param weights Tensor whose rank is either 0, or the same rank as\n * `labels`, and must be broadcastable to `labels` (i.e., all dimensions\n * must be either `1`, or the same as the corresponding `losses`\n * dimension).\n * @param delta Point where huber loss changes from quadratic to linear.\n * @param reduction Type of reduction to apply to loss. Should be of type\n * `Reduction`.\n */\n/** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */\nfunction huberLoss_(\n labels: T|TensorLike, predictions: T|TensorLike,\n weights?: Tensor|TensorLike, delta = 1.0,\n reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {\n const $labels = convertToTensor(labels, 'labels', 'huberLoss');\n const $predictions = convertToTensor(predictions, 'predictions', 'huberLoss');\n let $weights: Tensor = null;\n if (weights != null) {\n $weights = convertToTensor(weights, 'weights', 'huberLoss');\n }\n assertShapesMatch($labels.shape, $predictions.shape, 'Error in huberLoss: ');\n\n const deltaScalar = scalar(delta);\n const error = $predictions.sub($labels).abs();\n const quadratic = minimum(error, deltaScalar);\n const linear = error.sub(quadratic);\n\n const losses =\n scalar(0.5).mul(quadratic.square()).add(deltaScalar.mul(linear));\n return computeWeightedLoss(losses, $weights, reduction);\n}\n\n/**\n * Computes softmax cross entropy between logits and labels.\n *\n * Measures the probability error in discrete classification tasks in which\n * the classes are mutually exclusive (each entry is in exactly one class).\n * For example, each CIFAR-10 image is labeled with one and only one label: an\n * image can be a dog or a truck, but not both.\n *\n * `NOTE`: While the classes are mutually exclusive, their probabilities need\n * not be. All that is required is that each row of labels is a valid\n * probability distribution. If they are not, the computation of the gradient\n * will be incorrect.\n *\n * `WARNING`: This op expects unscaled logits, since it performs a softmax on\n * logits internally for efficiency. Do not call this op with the output of\n * softmax, as it will produce incorrect results.\n *\n * logits and labels must have the same shape, e.g. [batch_size, num_classes]\n * and the same dtype.\n * @param labels The labels array.\n * @param logits The logits array.\n * @param dim The dimension softmax would be performed on. Defaults to `-1`\n * which indicates the last dimension.\n */\nfunction softmaxCrossEntropyWithLogits_(\n labels: T, logits: T, dim = -1): O {\n if (dim === -1) {\n dim = logits.rank - 1;\n }\n\n if (dim !== logits.rank - 1) {\n throw Error(\n `Softmax cross entropy along a non-last dimension is not yet ` +\n `supported. Labels / logits was rank ${logits.rank} ` +\n `and dim was ${dim}`);\n }\n // Use a custom gradient for numerical stability.\n const customOp =\n customGrad((labels: Tensor, logits: Tensor, save: GradSaveFunc) => {\n // Reference:\n // 1. http://cs231n.github.io/linear-classify/#softmax\n // 2. https://blog.feedly.com/tricks-of-the-trade-logsumexp/\n const keepDims = true;\n const lse = logits.logSumExp([dim], keepDims);\n const logResult = logits.toFloat().sub(lse);\n save([labels, logResult]);\n\n const costVector = logResult.mul(labels).neg();\n const value: O = costVector.sum([dim]);\n\n const gradFunc = (dy: O, saved: Tensor[]) => {\n const [labels, logResult] = saved;\n const dyShape = expandShapeToKeepDim(dy.shape, [dim]);\n return [\n dy.reshape(dyShape).mul(labels.toFloat().sub(logResult.exp())),\n dy.reshape(dyShape).mul(logResult.exp().sub(labels.toFloat())),\n ];\n };\n return {value, gradFunc};\n });\n\n return customOp(labels, logits);\n}\n\n/**\n * Computes the softmax cross entropy loss between two tensors.\n *\n * If labelSmoothing is nonzero, smooth the labels towards 1/2:\n *\n * newOnehotLabels = onehotLabels * (1 - labelSmoothing)\n * + labelSmoothing / numClasses\n *\n * @param onehotLabels One hot encoded labels\n * [batch_size, num_classes], same dimensions as 'predictions'.\n * @param logits The predicted outputs.\n * @param weights Tensor whose rank is either 0, or 1, and must be\n * broadcastable to `loss` of shape [batch_size]\n * @param labelSmoothing If greater than 0, then smooth the labels.\n * @param reduction Type of reduction to apply to loss. Should be of type\n * `Reduction`\n */\n/** @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' } */\nfunction softmaxCrossEntropy_(\n onehotLabels: T|TensorLike, logits: T|TensorLike,\n weights?: Tensor|TensorLike, labelSmoothing = 0,\n reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {\n let $onehotLabels =\n convertToTensor(onehotLabels, 'onehotLabels', 'softmaxCrossEntropy');\n const $logits = convertToTensor(logits, 'logits', 'softmaxCrossEntropy');\n let $weights: Tensor = null;\n\n if (weights != null) {\n $weights = convertToTensor(weights, 'weights', 'softmaxCrossEntropy');\n }\n\n assertShapesMatch(\n $onehotLabels.shape, $logits.shape, 'Error in softmaxCrossEntropy: ');\n\n if (labelSmoothing > 0) {\n const labelSmoothingScalar = scalar(labelSmoothing);\n const one = scalar(1);\n const numClasses = scalar($onehotLabels.shape[1]);\n\n $onehotLabels = $onehotLabels.mul(one.sub(labelSmoothingScalar))\n .add(labelSmoothingScalar.div(numClasses));\n }\n\n const losses = softmaxCrossEntropyWithLogits_($onehotLabels, $logits);\n\n return computeWeightedLoss(losses, $weights, reduction);\n}\n\nexport const absoluteDifference = op({absoluteDifference_});\nexport const computeWeightedLoss = op({computeWeightedLoss_});\nexport const cosineDistance = op({cosineDistance_});\nexport const hingeLoss = op({hingeLoss_});\nexport const huberLoss = op({huberLoss_});\nexport const logLoss = op({logLoss_});\nexport const meanSquaredError = op({meanSquaredError_});\nexport const sigmoidCrossEntropy = op({sigmoidCrossEntropy_});\nexport const softmaxCrossEntropy = op({softmaxCrossEntropy_});\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {assert, assertShapesMatch, getTypedArrayFromDType} from '../util';\nimport {tensor} from './tensor_ops';\n\n/**\n * Returns whether the targets are in the top K predictions.\n *\n * ```js\n * const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);\n * const targets = tf.tensor1d([2, 0]);\n * const precision = await tf.inTopKAsync(predictions, targets);\n * precision.print();\n * ```\n * @param predictions 2-D or higher `tf.Tensor` with last dimension being\n * at least `k`.\n * @param targets 1-D or higher `tf.Tensor`.\n * @param k Optional Number of top elements to look at for computing precision,\n * default to 1.\n */\n/** @doc {heading: 'Operations', subheading: 'Evaluation'} */\nasync function inTopKAsync_(\n predictions: T|TensorLike, targets: U|TensorLike, k = 1): Promise {\n const $predictions = convertToTensor(predictions, 'predictions', 'inTopK');\n const $targets = convertToTensor(targets, 'targets', 'inTopK');\n\n assert(\n $predictions.rank > 1,\n () => 'inTopK() expects the predictions to be of rank 2 or higher, ' +\n `but got ${$predictions.rank}`);\n assert(\n $predictions.rank - 1 === $targets.rank,\n () => `predictions rank should be 1 larger than ` +\n `targets rank, but got predictions rank ` +\n `${$predictions.rank} and targets rank ${$targets.rank}`);\n assertShapesMatch(\n $predictions.shape.slice(0, $predictions.shape.length - 1),\n $targets.shape,\n `predictions's shape should be align with the targets' shape, ` +\n 'except the last dimension.');\n const lastDim = $predictions.shape[$predictions.shape.length - 1];\n assert(\n k > 0 && k <= lastDim,\n () => `'k' passed to inTopK() must be > 0 && <= the predictions last ` +\n `dimension (${lastDim}), but got ${k}`);\n\n const predictionsVals = await $predictions.data();\n const targetsVals = await $targets.data();\n\n // Reshape predictionsVals into a 2d tensor [batch, lastDim]\n // and look up topK along lastDim.\n const [batch, size] = [predictionsVals.length / lastDim, lastDim];\n const precision = getTypedArrayFromDType('bool', batch);\n\n for (let b = 0; b < batch; b++) {\n const offset = b * size;\n const vals = predictionsVals.subarray(offset, offset + size);\n const valAndInd: Array<{value: number, index: number}> = [];\n for (let i = 0; i < vals.length; i++) {\n valAndInd.push({value: vals[i], index: i});\n }\n valAndInd.sort((a, b) => b.value - a.value);\n\n precision[b] = 0;\n for (let i = 0; i < k; i++) {\n if (valAndInd[i].index === targetsVals[b]) {\n precision[b] = 1;\n break;\n }\n }\n }\n\n if (predictions !== $predictions) {\n $predictions.dispose();\n }\n if (targets !== $targets) {\n $targets.dispose();\n }\n\n // Output precision has the same shape as targets.\n return tensor(precision, $targets.shape, 'bool') as U;\n}\n\nexport const inTopKAsync = inTopKAsync_;\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * Linear algebra ops.\n */\n\nimport {ENGINE} from '../engine';\nimport {dispose} from '../globals';\nimport {Tensor, Tensor1D, Tensor2D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport {assert} from '../util';\nimport {squeeze, stack, unstack} from './array_ops';\nimport {sub} from './binary_ops';\nimport {split} from './concat_split';\nimport {eye} from './eye';\nimport {logicalAnd, where} from './logical_ops';\nimport {norm} from './norm';\nimport {op} from './operation';\nimport {sum} from './reduction_ops';\nimport {range, scalar, tensor2d, zeros} from './tensor_ops';\n\n/**\n * Copy a tensor setting everything outside a central band in each innermost\n * matrix to zero.\n *\n * The band part is computed as follows: Assume input has `k` dimensions\n * `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where\n * `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.\n * The indicator function\n * `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower))`\n * `&& (num_upper < 0 || (n-m) <= num_upper)`\n *\n * ```js\n * const x = tf.tensor2d([[ 0, 1, 2, 3],\n * [-1, 0, 1, 2],\n * [-2, -1, 0, 1],\n * [-3, -2, -1, 0]]);\n * let y = tf.linalg.bandPart(x, 1, -1);\n * y.print(); // [[ 0, 1, 2, 3],\n * // [-1, 0, 1, 2],\n * // [ 0, -1, 0, 1],\n * // [ 0, 0 , -1, 0]]\n * let z = tf.linalg.bandPart(x, 2, 1);\n * z.print(); // [[ 0, 1, 0, 0],\n * // [-1, 0, 1, 0],\n * // [-2, -1, 0, 1],\n * // [ 0, -2, -1, 0]]\n * ```\n *\n * @param x Rank `k` tensor\n * @param numLower Number of subdiagonals to keep.\n * If negative, keep entire lower triangle.\n * @param numUpper Number of subdiagonals to keep.\n * If negative, keep entire upper triangle.\n * @returns Rank `k` tensor of the same shape as input.\n * The extracted banded tensor.\n */\n/**\n * @doc {heading:'Operations',\n * subheading:'Linear Algebra',\n * namespace:'linalg'}\n */\nfunction bandPart_(\n a: T|TensorLike, numLower: number, numUpper: number): T {\n if (numLower % 1 !== 0) {\n throw new Error(\n `bandPart(): numLower must be an integer, got ${numLower}.`);\n }\n if (numUpper % 1 !== 0) {\n throw new Error(\n `bandPart(): numUpper must be an integer, got ${numUpper}.`);\n }\n\n const $a = convertToTensor(a, 'a', 'bandPart');\n\n if ($a.rank < 2) {\n throw new Error(`bandPart(): Rank must be at least 2, got ${$a.rank}.`);\n }\n\n const shape = $a.shape, [M, N] = $a.shape.slice(-2);\n\n if (!(numLower <= M)) {\n throw new Error(\n `bandPart(): numLower (${numLower})` +\n ` must not be greater than the number of rows (${M}).`);\n }\n if (!(numUpper <= N)) {\n throw new Error(\n `bandPart(): numUpper (${numUpper})` +\n ` must not be greater than the number of columns (${N}).`);\n }\n\n if (numLower < 0) {\n numLower = M;\n }\n if (numUpper < 0) {\n numUpper = N;\n }\n\n const i = range(0, M, 1, 'int32').reshape([-1, 1]),\n j = range(0, N, 1, 'int32'), ij = sub(i, j);\n\n const inBand = logicalAnd(\n ij.lessEqual(scalar(+numLower, 'int32')),\n ij.greaterEqual(scalar(-numUpper, 'int32')));\n\n const zero = zeros([M, N], $a.dtype);\n\n return stack(unstack($a.reshape([-1, M, N]))\n .map(mat => where(inBand, mat, zero)))\n .reshape(shape) as T;\n}\n\n/**\n * Gram-Schmidt orthogonalization.\n *\n * ```js\n * const x = tf.tensor2d([[1, 2], [3, 4]]);\n * let y = tf.linalg.gramSchmidt(x);\n * y.print();\n * console.log('Othogonalized:');\n * y.dot(y.transpose()).print(); // should be nearly the identity matrix.\n * console.log('First row direction maintained:');\n * const data = await y.array();\n * console.log(data[0][1] / data[0][0]); // should be nearly 2.\n * ```\n *\n * @param xs The vectors to be orthogonalized, in one of the two following\n * formats:\n * - An Array of `tf.Tensor1D`.\n * - A `tf.Tensor2D`, i.e., a matrix, in which case the vectors are the rows\n * of `xs`.\n * In each case, all the vectors must have the same length and the length\n * must be greater than or equal to the number of vectors.\n * @returns The orthogonalized and normalized vectors or matrix.\n * Orthogonalization means that the vectors or the rows of the matrix\n * are orthogonal (zero inner products). Normalization means that each\n * vector or each row of the matrix has an L2 norm that equals `1`.\n */\n/**\n * @doc {heading:'Operations',\n * subheading:'Linear Algebra',\n * namespace:'linalg'}\n */\nfunction gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D {\n let inputIsTensor2D: boolean;\n if (Array.isArray(xs)) {\n inputIsTensor2D = false;\n assert(\n xs != null && xs.length > 0,\n () => 'Gram-Schmidt process: input must not be null, undefined, or ' +\n 'empty');\n const dim = xs[0].shape[0];\n for (let i = 1; i < xs.length; ++i) {\n assert(\n xs[i].shape[0] === dim,\n () =>\n 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' +\n `(${(xs as Tensor1D[])[i].shape[0]} vs. ${dim})`);\n }\n } else {\n inputIsTensor2D = true;\n xs = split(xs, xs.shape[0], 0).map(x => squeeze(x, [0]));\n }\n\n assert(\n xs.length <= xs[0].shape[0],\n () => `Gram-Schmidt: Number of vectors (${\n (xs as Tensor1D[]).length}) exceeds ` +\n `number of dimensions (${(xs as Tensor1D[])[0].shape[0]}).`);\n\n const ys: Tensor1D[] = [];\n const xs1d = xs;\n for (let i = 0; i < xs.length; ++i) {\n ys.push(ENGINE.tidy(() => {\n let x = xs1d[i];\n if (i > 0) {\n for (let j = 0; j < i; ++j) {\n const proj = sum(ys[j].mulStrict(x)).mul(ys[j]);\n x = x.sub(proj);\n }\n }\n return x.div(norm(x, 'euclidean'));\n }));\n }\n\n if (inputIsTensor2D) {\n return stack(ys, 0) as Tensor2D;\n } else {\n return ys;\n }\n}\n\n/**\n * Compute QR decomposition of m-by-n matrix using Householder transformation.\n *\n * Implementation based on\n * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf]\n * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf)\n *\n * ```js\n * const a = tf.tensor2d([[1, 2], [3, 4]]);\n * let [q, r] = tf.linalg.qr(a);\n * console.log('Q');\n * q.print();\n * console.log('R');\n * r.print();\n * console.log('Orthogonalized');\n * q.dot(q.transpose()).print() // should be nearly the identity matrix.\n * console.log('Reconstructed');\n * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]];\n * ```\n *\n * @param x The `tf.Tensor` to be QR-decomposed. Must have rank >= 2. Suppose\n * it has the shape `[..., M, N]`.\n * @param fullMatrices An optional boolean parameter. Defaults to `false`.\n * If `true`, compute full-sized `Q`. If `false` (the default),\n * compute only the leading N columns of `Q` and `R`.\n * @returns An `Array` of two `tf.Tensor`s: `[Q, R]`. `Q` is a unitary matrix,\n * i.e., its columns all have unit norm and are mutually orthogonal.\n * If `M >= N`,\n * If `fullMatrices` is `false` (default),\n * - `Q` has a shape of `[..., M, N]`,\n * - `R` has a shape of `[..., N, N]`.\n * If `fullMatrices` is `true` (default),\n * - `Q` has a shape of `[..., M, M]`,\n * - `R` has a shape of `[..., M, N]`.\n * If `M < N`,\n * - `Q` has a shape of `[..., M, M]`,\n * - `R` has a shape of `[..., M, N]`.\n * @throws If the rank of `x` is less than 2.\n */\n/**\n * @doc {heading:'Operations',\n * subheading:'Linear Algebra',\n * namespace:'linalg'}\n */\nfunction qr_(x: Tensor, fullMatrices = false): [Tensor, Tensor] {\n if (x.rank < 2) {\n throw new Error(\n `qr() requires input tensor to have a rank >= 2, but got rank ${\n x.rank}`);\n } else if (x.rank === 2) {\n return qr2d(x as Tensor2D, fullMatrices);\n } else {\n // Rank > 2.\n // TODO(cais): Below we split the input into individual 2D tensors,\n // perform QR decomposition on them and then stack the results back\n // together. We should explore whether this can be parallelized.\n const outerDimsProd = x.shape.slice(0, x.shape.length - 2)\n .reduce((value, prev) => value * prev);\n const x2ds = unstack(\n x.reshape([\n outerDimsProd, x.shape[x.shape.length - 2],\n x.shape[x.shape.length - 1]\n ]),\n 0);\n const q2ds: Tensor2D[] = [];\n const r2ds: Tensor2D[] = [];\n x2ds.forEach(x2d => {\n const [q2d, r2d] = qr2d(x2d as Tensor2D, fullMatrices);\n q2ds.push(q2d);\n r2ds.push(r2d);\n });\n const q = stack(q2ds, 0).reshape(x.shape);\n const r = stack(r2ds, 0).reshape(x.shape);\n return [q, r];\n }\n}\n\nfunction qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] {\n return ENGINE.tidy(() => {\n if (x.shape.length !== 2) {\n throw new Error(\n `qr2d() requires a 2D Tensor, but got a ${x.shape.length}D Tensor.`);\n }\n\n const m = x.shape[0];\n const n = x.shape[1];\n\n let q = eye(m); // Orthogonal transform so far.\n let r = x.clone(); // Transformed matrix so far.\n\n const one2D = tensor2d([[1]], [1, 1]);\n let w: Tensor2D = one2D.clone();\n\n const iters = m >= n ? n : m;\n for (let j = 0; j < iters; ++j) {\n // This tidy within the for-loop ensures we clean up temporary\n // tensors as soon as they are no longer needed.\n const rTemp = r;\n const wTemp = w;\n const qTemp = q;\n [w, r, q] = ENGINE.tidy((): [Tensor2D, Tensor2D, Tensor2D] => {\n // Find H = I - tau * w * w', to put zeros below R(j, j).\n const rjEnd1 = r.slice([j, j], [m - j, 1]);\n const normX = rjEnd1.norm();\n const rjj = r.slice([j, j], [1, 1]);\n\n // The sign() function returns 0 on 0, which causes division by zero.\n const s = tensor2d([[-1]]).where(rjj.greater(0), tensor2d([[1]]));\n\n const u1 = rjj.sub(s.mul(normX));\n const wPre = rjEnd1.div(u1);\n if (wPre.shape[0] === 1) {\n w = one2D.clone();\n } else {\n w = one2D.concat(\n wPre.slice([1, 0], [wPre.shape[0] - 1, wPre.shape[1]]) as\n Tensor2D,\n 0);\n }\n const tau = s.matMul(u1).div(normX).neg() as Tensor2D;\n\n // -- R := HR, Q := QH.\n const rjEndAll = r.slice([j, 0], [m - j, n]);\n const tauTimesW: Tensor2D = tau.mul(w);\n const wT: Tensor2D = w.transpose();\n if (j === 0) {\n r = rjEndAll.sub(tauTimesW.matMul(wT.matMul(rjEndAll)));\n } else {\n const rTimesTau: Tensor2D =\n rjEndAll.sub(tauTimesW.matMul(wT.matMul(rjEndAll)));\n r = r.slice([0, 0], [j, n]).concat(rTimesTau, 0);\n }\n const tawTimesWT: Tensor2D = tauTimesW.transpose();\n const qAllJEnd = q.slice([0, j], [m, q.shape[1] - j]);\n if (j === 0) {\n q = qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tawTimesWT));\n } else {\n const qTimesTau: Tensor2D =\n qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tawTimesWT));\n q = q.slice([0, 0], [m, j]).concat(qTimesTau, 1);\n }\n return [w, r, q];\n });\n dispose([rTemp, wTemp, qTemp]);\n }\n\n if (!fullMatrices && m > n) {\n q = q.slice([0, 0], [m, n]);\n r = r.slice([0, 0], [n, n]);\n }\n\n return [q, r];\n }) as [Tensor2D, Tensor2D];\n}\n\nexport const bandPart = op({bandPart_});\nexport const gramSchmidt = op({gramSchmidt_});\nexport const qr = op({qr_});\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {nonMaxSuppressionV3, nonMaxSuppressionV5} from '../backends/non_max_suppression_impl';\nimport {ENGINE, ForwardFunc} from '../engine';\nimport {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {op} from './operation';\n\n/**\n * Bilinear resize a batch of 3D images to a new shape.\n *\n * @param images The images, of rank 4 or rank 3, of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.\n * @param size The new shape `[newHeight, newWidth]` to resize the\n * images to. Each channel is resized individually.\n * @param alignCorners Defaults to False. If true, rescale\n * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4\n * corners of images and resized images. If false, rescale by\n * `new_height / height`. Treat similarly the width dimension.\n */\n/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */\nfunction resizeBilinear_(\n images: T|TensorLike, size: [number, number], alignCorners = false): T {\n const $images = convertToTensor(images, 'images', 'resizeBilinear');\n util.assert(\n $images.rank === 3 || $images.rank === 4,\n () => `Error in resizeBilinear: x must be rank 3 or 4, but got ` +\n `rank ${$images.rank}.`);\n util.assert(\n size.length === 2,\n () => `Error in resizeBilinear: new shape must 2D, but got shape ` +\n `${size}.`);\n\n let batchImages = $images as Tensor4D;\n let reshapedTo4D = false;\n if ($images.rank === 3) {\n reshapedTo4D = true;\n batchImages =\n $images.as4D(1, $images.shape[0], $images.shape[1], $images.shape[2]);\n }\n\n const [newHeight, newWidth] = size;\n const forward: ForwardFunc = (backend, save) => {\n save([batchImages]);\n return backend.resizeBilinear(\n batchImages, newHeight, newWidth, alignCorners);\n };\n\n const backward = (dy: Tensor4D, saved: Tensor[]) => {\n return {\n x: () => ENGINE.runKernelFunc(\n backend => backend.resizeBilinearBackprop(\n dy, saved[0] as Tensor4D, alignCorners),\n {})\n };\n };\n\n const res = ENGINE.runKernelFunc(\n forward, {x: batchImages}, backward, 'ResizeBilinear',\n {alignCorners, newHeight, newWidth});\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n return res as T;\n}\n\n/**\n * NearestNeighbor resize a batch of 3D images to a new shape.\n *\n * @param images The images, of rank 4 or rank 3, of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.\n * @param size The new shape `[newHeight, newWidth]` to resize the\n * images to. Each channel is resized individually.\n * @param alignCorners Defaults to False. If true, rescale\n * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4\n * corners of images and resized images. If false, rescale by\n * `new_height / height`. Treat similarly the width dimension.\n */\n/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */\nfunction resizeNearestNeighbor_(\n images: T|TensorLike, size: [number, number], alignCorners = false): T {\n const $images = convertToTensor(images, 'images', 'resizeNearestNeighbor');\n util.assert(\n $images.rank === 3 || $images.rank === 4,\n () => `Error in resizeNearestNeighbor: x must be rank 3 or 4, but got ` +\n `rank ${$images.rank}.`);\n util.assert(\n size.length === 2,\n () =>\n `Error in resizeNearestNeighbor: new shape must 2D, but got shape ` +\n `${size}.`);\n util.assert(\n $images.dtype === 'float32' || $images.dtype === 'int32',\n () => '`images` must have `int32` or `float32` as dtype');\n\n let batchImages = $images as Tensor4D;\n let reshapedTo4D = false;\n if ($images.rank === 3) {\n reshapedTo4D = true;\n batchImages =\n $images.as4D(1, $images.shape[0], $images.shape[1], $images.shape[2]);\n }\n const [newHeight, newWidth] = size;\n\n const forward: ForwardFunc = (backend, save) => {\n save([batchImages]);\n return backend.resizeNearestNeighbor(\n batchImages, newHeight, newWidth, alignCorners);\n };\n\n const backward = (dy: Tensor4D, saved: Tensor[]) => {\n return {\n batchImages: () => ENGINE.runKernelFunc(\n backend => backend.resizeNearestNeighborBackprop(\n dy, saved[0] as Tensor4D, alignCorners),\n {})\n };\n };\n\n const res = ENGINE.runKernelFunc(forward, {batchImages}, backward);\n\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n return res as T;\n}\n\n/**\n * Performs non maximum suppression of bounding boxes based on\n * iou (intersection over union).\n *\n * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is\n * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of\n * the bounding box.\n * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.\n * @param maxOutputSize The maximum number of boxes to be selected.\n * @param iouThreshold A float representing the threshold for deciding whether\n * boxes overlap too much with respect to IOU. Must be between [0, 1].\n * Defaults to 0.5 (50% box overlap).\n * @param scoreThreshold A threshold for deciding when to remove boxes based\n * on score. Defaults to -inf, which means any score is accepted.\n * @return A 1D tensor with the selected box indices.\n */\n/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */\nfunction nonMaxSuppression_(\n boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike,\n maxOutputSize: number, iouThreshold = 0.5,\n scoreThreshold = Number.NEGATIVE_INFINITY): Tensor1D {\n const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');\n const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');\n\n const inputs = nonMaxSuppSanityCheck(\n $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);\n maxOutputSize = inputs.maxOutputSize;\n iouThreshold = inputs.iouThreshold;\n scoreThreshold = inputs.scoreThreshold;\n\n const attrs = {maxOutputSize, iouThreshold, scoreThreshold};\n return ENGINE.runKernelFunc(\n b => b.nonMaxSuppression(\n $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold),\n {boxes: $boxes, scores: $scores}, null /* grad */, 'NonMaxSuppressionV3',\n attrs);\n}\n\n/** This is the async version of `nonMaxSuppression` */\nasync function nonMaxSuppressionAsync_(\n boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike,\n maxOutputSize: number, iouThreshold = 0.5,\n scoreThreshold = Number.NEGATIVE_INFINITY): Promise {\n const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');\n const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');\n\n const inputs = nonMaxSuppSanityCheck(\n $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);\n maxOutputSize = inputs.maxOutputSize;\n iouThreshold = inputs.iouThreshold;\n scoreThreshold = inputs.scoreThreshold;\n\n const boxesAndScores = await Promise.all([$boxes.data(), $scores.data()]);\n const boxesVals = boxesAndScores[0];\n const scoresVals = boxesAndScores[1];\n\n const res = nonMaxSuppressionV3(\n boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);\n if ($boxes !== boxes) {\n $boxes.dispose();\n }\n if ($scores !== scores) {\n $scores.dispose();\n }\n return res;\n}\n\n/**\n * Performs non maximum suppression of bounding boxes based on\n * iou (intersection over union).\n *\n * This op also supports a Soft-NMS mode (c.f.\n * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score\n * of other overlapping boxes, therefore favoring different regions of the image\n * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`\n * parameter to be larger than 0.\n *\n * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is\n * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of\n * the bounding box.\n * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.\n * @param maxOutputSize The maximum number of boxes to be selected.\n * @param iouThreshold A float representing the threshold for deciding whether\n * boxes overlap too much with respect to IOU. Must be between [0, 1].\n * Defaults to 0.5 (50% box overlap).\n * @param scoreThreshold A threshold for deciding when to remove boxes based\n * on score. Defaults to -inf, which means any score is accepted.\n * @param softNmsSigma A float representing the sigma parameter for Soft NMS.\n * When sigma is 0, it falls back to nonMaxSuppression.\n * @return A map with the following properties:\n * - selectedIndices: A 1D tensor with the selected box indices.\n * - selectedScores: A 1D tensor with the corresponding scores for each\n * selected box.\n */\n/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */\nfunction nonMaxSuppressionWithScore_(\n boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike,\n maxOutputSize: number, iouThreshold = 0.5,\n scoreThreshold = Number.NEGATIVE_INFINITY,\n softNmsSigma = 0.0): NamedTensorMap {\n const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');\n const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');\n\n const inputs = nonMaxSuppSanityCheck(\n $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold,\n softNmsSigma);\n maxOutputSize = inputs.maxOutputSize;\n iouThreshold = inputs.iouThreshold;\n scoreThreshold = inputs.scoreThreshold;\n softNmsSigma = inputs.softNmsSigma;\n\n const attrs = {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma};\n\n const result = ENGINE.runKernel(\n 'NonMaxSuppressionV5', {boxes: $boxes, scores: $scores},\n attrs) as Tensor[];\n\n return {selectedIndices: result[0], selectedScores: result[1]};\n}\n\n/** This is the async version of `nonMaxSuppressionWithScore` */\nasync function nonMaxSuppressionWithScoreAsync_(\n boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike,\n maxOutputSize: number, iouThreshold = 0.5,\n scoreThreshold = Number.NEGATIVE_INFINITY,\n softNmsSigma = 0.0): Promise {\n const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');\n const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');\n\n const inputs = nonMaxSuppSanityCheck(\n $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold,\n softNmsSigma);\n maxOutputSize = inputs.maxOutputSize;\n iouThreshold = inputs.iouThreshold;\n scoreThreshold = inputs.scoreThreshold;\n softNmsSigma = inputs.softNmsSigma;\n\n const boxesAndScores = await Promise.all([$boxes.data(), $scores.data()]);\n const boxesVals = boxesAndScores[0];\n const scoresVals = boxesAndScores[1];\n\n const res = nonMaxSuppressionV5(\n boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold,\n softNmsSigma);\n\n if ($boxes !== boxes) {\n $boxes.dispose();\n }\n if ($scores !== scores) {\n $scores.dispose();\n }\n return res;\n}\n\nfunction nonMaxSuppSanityCheck(\n boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number,\n iouThreshold: number, scoreThreshold: number, softNmsSigma?: number): {\n maxOutputSize: number,\n iouThreshold: number,\n scoreThreshold: number,\n softNmsSigma: number\n} {\n if (iouThreshold == null) {\n iouThreshold = 0.5;\n }\n if (scoreThreshold == null) {\n scoreThreshold = Number.NEGATIVE_INFINITY;\n }\n if (softNmsSigma == null) {\n softNmsSigma = 0.0;\n }\n\n const numBoxes = boxes.shape[0];\n maxOutputSize = Math.min(maxOutputSize, numBoxes);\n\n util.assert(\n 0 <= iouThreshold && iouThreshold <= 1,\n () => `iouThreshold must be in [0, 1], but was '${iouThreshold}'`);\n util.assert(\n boxes.rank === 2,\n () => `boxes must be a 2D tensor, but was of rank '${boxes.rank}'`);\n util.assert(\n boxes.shape[1] === 4,\n () =>\n `boxes must have 4 columns, but 2nd dimension was ${boxes.shape[1]}`);\n util.assert(scores.rank === 1, () => 'scores must be a 1D tensor');\n util.assert(\n scores.shape[0] === numBoxes,\n () => `scores has incompatible shape with boxes. Expected ${numBoxes}, ` +\n `but was ${scores.shape[0]}`);\n util.assert(\n 0 <= softNmsSigma && softNmsSigma <= 1,\n () => `softNmsSigma must be in [0, 1], but was '${softNmsSigma}'`);\n return {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma};\n}\n\n/**\n * Extracts crops from the input image tensor and resizes them using bilinear\n * sampling or nearest neighbor sampling (possibly with aspect ratio change)\n * to a common output size specified by crop_size.\n *\n * @param image 4d tensor of shape `[batch,imageHeight,imageWidth, depth]`,\n * where imageHeight and imageWidth must be positive, specifying the\n * batch of images from which to take crops\n * @param boxes 2d float32 tensor of shape `[numBoxes, 4]`. Each entry is\n * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the normalized\n * coordinates of the box in the boxInd[i]'th image in the batch\n * @param boxInd 1d int32 tensor of shape `[numBoxes]` with values in range\n * `[0, batch)` that specifies the image that the `i`-th box refers to.\n * @param cropSize 1d int32 tensor of 2 elements `[cropHeigh, cropWidth]`\n * specifying the size to which all crops are resized to.\n * @param method Optional string from `'bilinear' | 'nearest'`,\n * defaults to bilinear, which specifies the sampling method for resizing\n * @param extrapolationValue A threshold for deciding when to remove boxes based\n * on score. Defaults to 0.\n * @return A 4D tensor of the shape `[numBoxes,cropHeight,cropWidth,depth]`\n */\n/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */\nfunction cropAndResize_(\n image: Tensor4D|TensorLike,\n boxes: Tensor2D|TensorLike,\n boxInd: Tensor1D|TensorLike,\n cropSize: [number, number],\n method?: 'bilinear'|'nearest',\n extrapolationValue?: number,\n ): Tensor4D {\n const $image = convertToTensor(image, 'image', 'cropAndResize');\n const $boxes = convertToTensor(boxes, 'boxes', 'cropAndResize', 'float32');\n const $boxInd = convertToTensor(boxInd, 'boxInd', 'cropAndResize', 'int32');\n method = method || 'bilinear';\n extrapolationValue = extrapolationValue || 0;\n\n const numBoxes = $boxes.shape[0];\n\n util.assert(\n $image.rank === 4,\n () => 'Error in cropAndResize: image must be rank 4,' +\n `but got rank ${$image.rank}.`);\n util.assert(\n $boxes.rank === 2 && $boxes.shape[1] === 4,\n () => `Error in cropAndResize: boxes must be have size [${numBoxes},4] ` +\n `but had shape ${$boxes.shape}.`);\n util.assert(\n $boxInd.rank === 1 && $boxInd.shape[0] === numBoxes,\n () => `Error in cropAndResize: boxInd must be have size [${numBoxes}] ` +\n `but had shape ${$boxes.shape}.`);\n util.assert(\n cropSize.length === 2,\n () => `Error in cropAndResize: cropSize must be of length 2, but got ` +\n `length ${cropSize.length}.`);\n util.assert(\n cropSize[0] >= 1 && cropSize[1] >= 1,\n () => `cropSize must be atleast [1,1], but was ${cropSize}`);\n util.assert(\n method === 'bilinear' || method === 'nearest',\n () => `method must be bilinear or nearest, but was ${method}`);\n\n const forward: ForwardFunc = (backend, save) =>\n backend.cropAndResize(\n $image, $boxes, $boxInd, cropSize, method, extrapolationValue);\n\n const res = ENGINE.runKernelFunc(\n forward, {images: $image, boxes: $boxes, boxInd: $boxInd}, null /* der */,\n 'CropAndResize', {method, extrapolationValue, cropSize});\n return res;\n}\n\nexport const resizeBilinear = op({resizeBilinear_});\nexport const resizeNearestNeighbor = op({resizeNearestNeighbor_});\nexport const nonMaxSuppression = op({nonMaxSuppression_});\nexport const nonMaxSuppressionAsync = nonMaxSuppressionAsync_;\nexport const nonMaxSuppressionWithScore = op({nonMaxSuppressionWithScore_});\nexport const nonMaxSuppressionWithScoreAsync = nonMaxSuppressionWithScoreAsync_;\nexport const cropAndResize = op({cropAndResize_});\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor, Tensor3D, Tensor4D} from '../tensor';\n\nimport {Conv2DInfo} from './conv_util';\n\nexport type Activation = 'linear'|'relu'|'prelu'|'elu'|'relu6';\n\nexport type FusedBatchMatMulConfig = {\n a: Tensor3D,\n b: Tensor3D,\n transposeA: boolean,\n transposeB: boolean,\n bias?: Tensor,\n activation?: Activation,\n preluActivationWeights?: Tensor\n};\n\nexport type FusedConv2DConfig = {\n input: Tensor4D,\n filter: Tensor4D,\n convInfo: Conv2DInfo,\n bias?: Tensor,\n activation?: Activation,\n preluActivationWeights?: Tensor\n};\n\n// Whether we should call fused ops.\nexport const shouldFuse = (gradientDepth: number, activation: Activation) => {\n const gradientMode = gradientDepth > 0;\n return !gradientMode || activation === 'linear';\n};\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {conv2dDerFilter, conv2dDerInput, depthwiseConv2dDerFilter, depthwiseConv2dDerInput} from '../ops/conv';\nimport * as conv_util from '../ops/conv_util';\nimport {op} from '../ops/operation';\nimport {Tensor, Tensor3D, Tensor4D} from '../tensor';\nimport {makeTypesMatch} from '../tensor_util';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {add} from './add';\nimport * as broadcast_util from './broadcast_util';\nimport {conv2d as unfusedConv2d, depthwiseConv2d as unfusedDepthwiseConv2d} from './conv';\nimport {Activation, shouldFuse} from './fused_util';\nimport {matMul as unfusedMatMul} from './matmul';\nimport {elu, prelu, relu, relu6} from './relu_ops';\n\n// Returns gradient for fused activation.\nconst getFusedDyActivation =\n (dy: Tensor, y: Tensor, activation: Activation): Tensor => {\n if (activation == null || activation === 'linear') {\n return dy;\n }\n if (activation === 'relu') {\n return dy.mul(y.step());\n }\n throw new Error(\n `Gradient for activation ${activation} has not been ` +\n `implemented yet.`);\n };\n\n// Returns gradient for fused bias.\nconst getFusedBiasGradient = (bias: Tensor, dyActivation: Tensor): Tensor => {\n let res = dyActivation;\n const reduceAxes =\n broadcast_util.getReductionAxes(bias.shape, dyActivation.shape);\n if (reduceAxes.length > 0) {\n res = res.sum(reduceAxes);\n }\n return res.reshape(bias.shape);\n};\n\nconst applyActivation =\n (x: Tensor, activation: Activation, preluActivationWeights?: Tensor):\n Tensor => {\n if (activation === 'linear') {\n return x;\n } else if (activation === 'relu') {\n return relu(x);\n } else if (activation === 'elu') {\n return elu(x);\n } else if (activation === 'relu6') {\n return relu6(x);\n } else if (activation === 'prelu') {\n return prelu(x, preluActivationWeights);\n }\n throw new Error(`Unknown fused activation ${activation}.`);\n };\n\n/**\n * Computes the dot product of two matrices with optional activation and bias.\n *\n * ```js\n * const a = tf.tensor2d([-1, -2], [1, 2]);\n * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n * const bias = tf.tensor2d([1, 2], [1, 2]);\n *\n * tf.fused.matMul({a, b, bias, activation: 'relu'}).print();\n * ```\n *\n * @param obj An object with the following properties:\n * - `a` First matrix in dot product operation.\n * - `b` Second matrix in dot product operation.\n * - `transposeA` If true, `a` is transposed before multiplication.\n * - `transposeB` If true, `b` is transposed before multiplication.\n * - `bias` Matrix to be added to the result.\n * - `activation` Name of activation kernel (defaults to `linear`).\n * - `preluActivationWeights` Tensor of prelu weights.\n */\nfunction fusedMatMul_({\n a,\n b,\n transposeA = false,\n transposeB = false,\n bias,\n activation = 'linear',\n preluActivationWeights\n}: {\n a: T|TensorLike,\n b: T|TensorLike,\n transposeA?: boolean,\n transposeB?: boolean,\n bias?: Tensor|TensorLike,\n activation?: Activation,\n preluActivationWeights?: Tensor\n}): T {\n if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {\n let result = unfusedMatMul(a, b, transposeA, transposeB);\n if (bias != null) {\n result = add(result, bias);\n }\n\n return applyActivation(result, activation, preluActivationWeights) as T;\n }\n\n let $a = convertToTensor(a, 'a', 'fused matMul');\n let $b = convertToTensor(b, 'b', 'fused matMul');\n [$a, $b] = makeTypesMatch($a, $b);\n\n const innerShapeA =\n transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];\n const innerShapeB =\n transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];\n\n const outerShapeA =\n transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];\n const outerShapeB =\n transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];\n\n const outerDimsA = $a.shape.slice(0, -2);\n const outerDimsB = $b.shape.slice(0, -2);\n const batchDimA = util.sizeFromShape(outerDimsA);\n const batchDimB = util.sizeFromShape(outerDimsB);\n\n util.assert(\n $a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank,\n () =>\n `Error in fused matMul: inputs must have the same rank of at least ` +\n `2, got ranks ${$a.rank} and ${$b.rank}.`);\n\n util.assert(\n util.arraysEqual(outerDimsA, outerDimsB),\n () => `Error in fused matMul: outer dimensions (${outerDimsA}) and (` +\n `${outerDimsB}) of Tensors with shapes ${$a.shape} and ` +\n `${$b.shape} must match.`);\n\n util.assert(\n innerShapeA === innerShapeB,\n () => `Error in fused matMul: inner shapes (${innerShapeA}) and (` +\n `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` +\n `${$b.shape} and transposeA=${transposeA}` +\n ` and transposeB=${transposeB} must match.`);\n\n const outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]);\n\n const a3D = transposeA ? $a.as3D(batchDimA, innerShapeA, outerShapeA) :\n $a.as3D(batchDimA, outerShapeA, innerShapeA);\n const b3D = transposeB ? $b.as3D(batchDimB, outerShapeB, innerShapeB) :\n $b.as3D(batchDimB, innerShapeB, outerShapeB);\n\n let $bias: Tensor;\n if (bias != null) {\n $bias = convertToTensor(bias, 'bias', 'fused matMul');\n [$bias] = makeTypesMatch($bias, $a);\n\n broadcast_util.assertAndGetBroadcastShape(outShape, $bias.shape);\n }\n\n let $preluActivationWeights: Tensor;\n if (preluActivationWeights != null) {\n $preluActivationWeights = convertToTensor(\n preluActivationWeights, 'prelu weights', 'fused matMul');\n }\n\n const grad = (dy: Tensor3D, saved: Tensor[]) => {\n const [a3D, b3D, y] = saved;\n const dyActivation = getFusedDyActivation(dy, y, activation);\n\n let biasGradient = {};\n if (bias != null) {\n biasGradient = {bias: () => getFusedBiasGradient($bias, dyActivation)};\n }\n\n if (!transposeA && !transposeB) {\n return Object.assign(\n {\n a: () => dyActivation.matMul(b3D as Tensor3D, false, true),\n b: () => a3D.matMul(dyActivation, true, false)\n },\n biasGradient);\n } else if (!transposeA && transposeB) {\n return Object.assign(\n {\n a: () => dyActivation.matMul(b3D as Tensor3D, false, false),\n b: () => dyActivation.matMul(a3D as Tensor3D, true, false)\n },\n biasGradient);\n } else if (transposeA && !transposeB) {\n return Object.assign(\n {\n a: () => b3D.matMul(dyActivation, false, true),\n b: () => a3D.matMul(dyActivation, false, false)\n },\n biasGradient);\n } else {\n return Object.assign(\n {\n a: () => b3D.matMul(dyActivation, true, true),\n b: () => dyActivation.matMul(a3D as Tensor3D, true, true)\n },\n biasGradient);\n }\n };\n\n const inputs:\n {a: Tensor, b: Tensor,\n bias?: Tensor,\n preluActivationWeights?: Tensor} = {a: a3D, b: b3D};\n if (bias != null) {\n inputs.bias = $bias;\n }\n if (preluActivationWeights != null) {\n inputs.preluActivationWeights = $preluActivationWeights;\n }\n\n const inputsToSave = [a3D, b3D];\n const outputsToSave = [true];\n\n const res = ENGINE.runKernelFunc(\n (backend, save) => {\n const y = backend.fusedBatchMatMul({\n a: a3D,\n b: b3D,\n transposeA,\n transposeB,\n bias: $bias,\n activation,\n preluActivationWeights: $preluActivationWeights\n });\n save([a3D, b3D, y]);\n return y;\n },\n inputs, grad, '_FusedMatMul', {transposeA, transposeB, activation},\n inputsToSave, outputsToSave);\n return res.reshape(outShape) as T;\n}\n\n/**\n * Computes a 2D convolution over the input x, optionally fused with adding a\n * bias and applying an activation.\n *\n * ```js\n * const inputDepth = 2;\n * const inShape = [2, 2, 2, inputDepth];\n * const outputDepth = 2;\n * const fSize = 1;\n * const pad = 0;\n * const strides = 1;\n *\n * const x = tf.tensor4d( [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n * 16], inShape);\n * const w = tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth,\n * outputDepth]);\n *\n * tf.fused.conv2d({ x, filter: w, strides, pad, dataFormat: 'NHWC',\n * dilations: [1, 1], bias: tf.scalar(5), activation: 'relu' }).print();\n * ```\n *\n * @param obj An object with the following properties:\n * @param x The input tensor, of rank 4 or rank 3, of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is\n * assumed.\n * @param filter The filter, rank 4, of shape\n * `[filterHeight, filterWidth, inDepth, outDepth]`.\n * @param strides The strides of the convolution: `[strideHeight,\n * strideWidth]`.\n * @param pad The type of padding algorithm.\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid` output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dataFormat An optional string from: \"NHWC\", \"NCHW\". Defaults to\n * \"NHWC\". Specify the data format of the input and output data. With the\n * default format \"NHWC\", the data is stored in the order of: [batch,\n * height, width, channels]. Only \"NHWC\" is currently supported.\n * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`\n * in which we sample input values across the height and width dimensions\n * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single\n * number, then `dilationHeight == dilationWidth`. If it is greater than\n * 1, then all values of `strides` must be 1.\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n * @param bias Tensor to be added to the result.\n * @param activation Name of activation kernel (defaults to `linear`) to be\n * applied\n * after biasAdd.\n * @param preluActivationWeights Tensor of prelu weights to be applied as part\n * of a `prelu` activation, typically the same shape as `x`.\n */\nfunction fusedConv2d_({\n x,\n filter,\n strides,\n pad,\n dataFormat = 'NHWC',\n dilations = [1, 1],\n dimRoundingMode,\n bias,\n activation = 'linear',\n preluActivationWeights\n}: {\n x: T|TensorLike,\n filter: Tensor4D|TensorLike,\n strides: [number, number]|number,\n pad: 'valid'|'same'|number,\n dataFormat?: 'NHWC'|'NCHW',\n dilations?: [number, number]|number,\n dimRoundingMode?: 'floor'|'round'|'ceil',\n bias?: Tensor|TensorLike,\n activation?: Activation,\n preluActivationWeights?: Tensor\n}): T {\n activation = activation || 'linear';\n if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {\n let result = unfusedConv2d(\n x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);\n if (bias != null) {\n result = add(result, bias);\n }\n\n return applyActivation(result, activation, preluActivationWeights) as T;\n }\n\n const $x = convertToTensor(x, 'x', 'conv2d');\n const $filter = convertToTensor(filter, 'filter', 'conv2d');\n\n let x4D = $x as Tensor4D;\n let reshapedTo4D = false;\n\n if ($x.rank === 3) {\n reshapedTo4D = true;\n x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]);\n }\n util.assert(\n x4D.rank === 4,\n () => `Error in fused conv2d: input must be rank 4, but got rank ` +\n `${x4D.rank}.`);\n util.assert(\n $filter.rank === 4,\n () => `Error in fused conv2d: filter must be rank 4, but got rank ` +\n `${$filter.rank}.`);\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in fused conv2d: pad must be an integer when using, ` +\n `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n\n util.assert(\n x4D.shape[3] === $filter.shape[2],\n () => `Error in conv2d: depth of input (${x4D.shape[3]}) must match ` +\n `input depth for filter ${$filter.shape[2]}.`);\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () => 'Error in conv2D: Either strides or dilations must be 1. ' +\n `Got strides ${strides} and dilations '${dilations}'`);\n util.assert(\n dataFormat === 'NHWC',\n () => `Error in conv2d: got dataFormat of ${\n dataFormat} but only NHWC is currently supported.`);\n\n const convInfo = conv_util.computeConv2DInfo(\n x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode);\n\n let $bias: Tensor;\n if (bias != null) {\n $bias = convertToTensor(bias, 'bias', 'fused conv2d');\n [$bias] = makeTypesMatch($bias, $x);\n\n broadcast_util.assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);\n }\n\n let $preluActivationWeights: Tensor;\n if (preluActivationWeights != null) {\n $preluActivationWeights = convertToTensor(\n preluActivationWeights, 'prelu weights', 'fused conv2d');\n }\n\n const grad = (dy: Tensor4D, saved: Tensor[]) => {\n const [$filter, x4D, y] = saved as [Tensor4D, Tensor4D, Tensor4D];\n\n const dyActivation = getFusedDyActivation(dy, y, activation) as Tensor4D;\n\n util.assert(\n conv_util.tupleValuesAreOne(dilations),\n () => 'Error in gradient of fused conv2D: ' +\n `dilation rates greater than 1 ` +\n `are not yet supported in gradients. Got dilations '${dilations}'`);\n\n let biasGradient = {};\n if (bias != null) {\n biasGradient = {bias: () => getFusedBiasGradient($bias, dyActivation)};\n }\n\n return Object.assign(\n {\n x: () =>\n conv2dDerInput(x4D.shape, dyActivation, $filter, strides, pad),\n filter: () =>\n conv2dDerFilter(x4D, dyActivation, $filter.shape, strides, pad)\n },\n biasGradient);\n };\n\n const inputs: {\n x: Tensor,\n filter: Tensor,\n bias?: Tensor,\n preluActivationWeights?: Tensor\n } = {x: x4D, filter: $filter};\n if (bias != null) {\n inputs.bias = $bias;\n }\n if (preluActivationWeights != null) {\n inputs.preluActivationWeights = $preluActivationWeights;\n }\n\n const inputsToSave = [$filter, x4D];\n const outputsToSave = [true]; // Save the only output.\n const res = ENGINE.runKernelFunc(\n (backend, save) => {\n const res = backend.fusedConv2d({\n input: x4D,\n filter: $filter,\n convInfo,\n bias: $bias,\n activation,\n preluActivationWeights: $preluActivationWeights\n });\n save([$filter, x4D, res]);\n return res;\n },\n inputs, grad, 'FusedConv2D', {convInfo, activation}, inputsToSave,\n outputsToSave);\n\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n\n return res as T;\n}\n\n/**\n * Computes depthwise 2D convolution, optionally fused with adding a\n * bias and applying an activation.\n *\n * Given a 4D `input` array and a `filter` array of shape\n * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing\n * `inChannels` convolutional filters of depth 1, this op applies a\n * different filter to each input channel (expanding from 1 channel to\n * `channelMultiplier` channels for each), then concatenates the results\n * together. The output has `inChannels * channelMultiplier` channels.\n *\n * See\n * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](\n * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)\n * for more details.\n *\n * @param obj An object with the following properties:\n * @param x The input tensor, of rank 4 or rank 3, of shape\n * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is\n * assumed.\n * @param filter The filter tensor, rank 4, of shape\n * `[filterHeight, filterWidth, inChannels, channelMultiplier]`.\n * @param strides The strides of the convolution: `[strideHeight,\n * strideWidth]`. If strides is a single number, then `strideHeight ==\n * strideWidth`.\n * @param pad The type of padding algorithm.\n * - `same` and stride 1: output will be of same size as input,\n * regardless of filter size.\n * - `valid`: output will be smaller than input if filter is larger\n * than 1x1.\n * - For more info, see this guide:\n * [https://www.tensorflow.org/api_guides/python/nn#Convolution](\n * https://www.tensorflow.org/api_guides/python/nn#Convolution)\n * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`\n * in which we sample input values across the height and width dimensions\n * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single\n * number, then `dilationHeight == dilationWidth`. If it is greater than\n * 1, then all values of `strides` must be 1.\n * @param dataFormat: An optional string from: \"NHWC\", \"NCHW\". Defaults to\n * \"NHWC\". Specify the data format of the input and output data. With the\n * default format \"NHWC\", the data is stored in the order of: [batch,\n * height, width, channels]. Only \"NHWC\" is currently supported.\n * @param dimRoundingMode The rounding mode used when computing output\n * dimensions if pad is a number. If none is provided, it will not round\n * and error if the output is of fractional size.\n * @param bias Tensor to be added to the result.\n * @param activation Name of activation kernel (defaults to `linear`).\n * @param preluActivationWeights Tensor of prelu weights to be applied as part\n * of a `prelu` activation, typically the same shape as `x`.\n */\nfunction fusedDepthwiseConv2d_({\n x,\n filter,\n strides,\n pad,\n dataFormat = 'NHWC',\n dilations = [1, 1],\n dimRoundingMode,\n bias,\n activation = 'linear',\n preluActivationWeights\n}: {\n x: T|TensorLike,\n filter: Tensor4D|TensorLike,\n strides: [number, number]|number,\n pad: 'valid'|'same'|number,\n dataFormat?: 'NHWC'|'NCHW',\n dilations?: [number, number]|number,\n dimRoundingMode?: 'floor'|'round'|'ceil',\n bias?: Tensor|TensorLike,\n activation?: Activation,\n preluActivationWeights?: Tensor\n}): T {\n if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {\n let result = unfusedDepthwiseConv2d(\n x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);\n if (bias != null) {\n result = add(result, bias);\n }\n\n return applyActivation(result, activation, preluActivationWeights) as T;\n }\n\n const $x = convertToTensor(x, 'x', 'depthwiseConv2d');\n const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d');\n\n let x4D = $x as Tensor4D;\n let reshapedTo4D = false;\n if ($x.rank === 3) {\n reshapedTo4D = true;\n x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]);\n }\n util.assert(\n x4D.rank === 4,\n () => `Error in fused depthwiseConv2d: input must be rank 4, but got ` +\n `rank ${x4D.rank}.`);\n util.assert(\n $filter.rank === 4,\n () => `Error in fused depthwiseConv2d: filter must be rank 4, ` +\n `but got rank ${$filter.rank}.`);\n util.assert(\n x4D.shape[3] === $filter.shape[2],\n () => `Error in fused depthwiseConv2d: number of input channels ` +\n `(${x4D.shape[3]}) must match the inChannels dimension in ` +\n `filter ${$filter.shape[2]}.`);\n if (dilations == null) {\n dilations = [1, 1];\n }\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () =>\n 'Error in fused depthwiseConv2d: Either strides or dilations must ' +\n `be 1. Got strides ${strides} and dilations '${dilations}'`);\n\n if (dimRoundingMode != null) {\n util.assert(\n util.isInt(pad as number),\n () => `Error in fused depthwiseConv2d: pad must be an integer when ` +\n `using dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);\n }\n\n const convInfo = conv_util.computeConv2DInfo(\n x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode,\n true /* depthwise */);\n\n let $bias: Tensor;\n if (bias != null) {\n $bias = convertToTensor(bias, 'bias', 'fused conv2d');\n [$bias] = makeTypesMatch($bias, $x);\n\n broadcast_util.assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);\n }\n\n let $preluActivationWeights: Tensor;\n if (preluActivationWeights != null) {\n $preluActivationWeights = convertToTensor(\n preluActivationWeights, 'prelu weights', 'fused depthwiseConv2d');\n }\n\n const grad = (dy: Tensor4D, saved: Tensor[]) => {\n util.assert(\n conv_util.tupleValuesAreOne(dilations),\n () => 'Error in gradient of fused depthwiseConv2d: dilation rates ' +\n `greater than 1 are not yet supported. Got dilations ` +\n `'${dilations}'`);\n const [$filter, x4D, y] = saved;\n\n const dyActivation = getFusedDyActivation(dy, y, activation) as Tensor4D;\n\n let biasGradient = {};\n if (bias != null) {\n biasGradient = {bias: () => getFusedBiasGradient($bias, dyActivation)};\n }\n\n return Object.assign(\n {\n x: () => depthwiseConv2dDerInput(\n (x4D as Tensor4D).shape, dyActivation, $filter as Tensor4D,\n convInfo),\n filter: () => depthwiseConv2dDerFilter(\n x4D as Tensor4D, dyActivation, ($filter as Tensor4D).shape,\n convInfo),\n },\n biasGradient);\n };\n\n const inputs: {\n x: Tensor,\n filter: Tensor,\n bias?: Tensor,\n preluActivationWeights?: Tensor\n } = {x: x4D, filter: $filter};\n if (bias != null) {\n inputs.bias = $bias;\n }\n if (preluActivationWeights != null) {\n inputs.preluActivationWeights = $preluActivationWeights;\n }\n\n const inputsToSave = [$filter, x4D];\n const outputsToSave = [true];\n const res = ENGINE.runKernelFunc(\n (backend, save) => {\n const res = backend.fusedDepthwiseConv2D({\n input: x4D,\n filter: $filter,\n convInfo,\n bias: $bias,\n activation,\n preluActivationWeights: $preluActivationWeights\n });\n save([$filter, x4D, res]);\n return res;\n },\n inputs, grad, 'FusedDepthwiseConv2D', {convInfo, activation},\n inputsToSave, outputsToSave);\n if (reshapedTo4D) {\n return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;\n }\n return res as T;\n}\n\nexport const matMul = op({fusedMatMul_});\nexport const conv2d = op({fusedConv2d_});\nexport const depthwiseConv2d = op({fusedDepthwiseConv2d_});\n\nexport {Activation};\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {TensorInfo} from '../../kernel_registry';\nimport {assert} from '../../util';\n\nexport function assertNotComplex(\n tensor: TensorInfo|TensorInfo[], opName: string): void {\n if (!Array.isArray(tensor)) {\n tensor = [tensor];\n }\n tensor.forEach(t => {\n if (t != null) {\n assert(\n t.dtype !== 'complex64',\n () => `${opName} does not support complex64 tensors.`);\n }\n });\n}\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Conv2DInfo} from '../../ops/conv_util';\nimport * as ops from '../../ops/ops';\nimport {buffer} from '../../ops/ops';\nimport {TensorBuffer} from '../../tensor';\nimport {DataType, Rank, TypedArray} from '../../types';\n\nexport function pool(\n xValues: TypedArray, xShape: number[], dtype: DataType, strides: number[],\n convInfo: Conv2DInfo, poolType: 'max'|'avg'): TensorBuffer {\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n\n const initialValue =\n (poolType === 'max' ? Number.NEGATIVE_INFINITY :\n Number.POSITIVE_INFINITY);\n\n const output = ops.buffer(convInfo.outShape, dtype);\n const outputVals = output.values;\n\n const outputBatchStrides =\n convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];\n const outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];\n const outputColStrides = convInfo.outShape[3];\n\n for (let b = 0; b < convInfo.batchSize; ++b) {\n const outputBatchOffset = b * outputBatchStrides;\n const inputBatchOffset = b * strides[0];\n for (let d = 0; d < convInfo.inChannels; ++d) {\n for (let yR = 0; yR < convInfo.outHeight; ++yR) {\n const xRCorner = yR * strideHeight - padTop;\n const xRMin = Math.max(0, xRCorner);\n const xRMax =\n Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);\n const outputRowOffset = outputBatchOffset + yR * outputRowStrides;\n for (let yC = 0; yC < convInfo.outWidth; ++yC) {\n const xCCorner = yC * strideWidth - padLeft;\n const xCMin = Math.max(0, xCCorner);\n const xCMax =\n Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);\n let minMaxValue = initialValue;\n let avgValue = 0;\n let count = 0;\n for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {\n const xROffset = inputBatchOffset + xR * strides[1];\n for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {\n const xCOffset = xROffset + xC * strides[2];\n const pixel = xValues[xCOffset + d];\n if ((poolType === 'max' && pixel > minMaxValue)) {\n minMaxValue = pixel;\n } else if (poolType === 'avg') {\n avgValue += pixel;\n count++;\n }\n }\n if (isNaN(minMaxValue)) {\n break;\n }\n }\n const outputOffset = outputRowOffset + yC * outputColStrides + d;\n outputVals[outputOffset] =\n poolType === 'avg' ? avgValue / count : minMaxValue;\n }\n }\n }\n }\n return output;\n}\n\nexport function maxPoolPositions(\n xValues: TypedArray, xShape: number[], dtype: DataType,\n convInfo: Conv2DInfo, flattenPositions = false,\n includeBatchInIndex = false): TensorBuffer {\n const maxPositions = ops.buffer(convInfo.outShape, 'int32');\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n\n const xBuf = buffer(xShape, dtype, xValues);\n for (let b = 0; b < convInfo.batchSize; ++b) {\n for (let d = 0; d < convInfo.inChannels; ++d) {\n for (let yR = 0; yR < convInfo.outHeight; ++yR) {\n const xRCorner = yR * strideHeight - padTop;\n let xRMin = xRCorner;\n while (xRMin < 0) {\n xRMin += dilationHeight;\n }\n // const xRMin = Math.max(0, xRCorner);\n const xRMax =\n Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);\n for (let yC = 0; yC < convInfo.outWidth; ++yC) {\n const xCCorner = yC * strideWidth - padLeft;\n let xCMin = xCCorner;\n while (xCMin < 0) {\n xCMin += dilationWidth;\n }\n const xCMax =\n Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);\n let maxValue = Number.NEGATIVE_INFINITY;\n let maxPosition = -1;\n\n for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {\n const wR = xR - xRCorner;\n for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {\n const wC = xC - xCCorner;\n const pixel = xBuf.get(b, xR, xC, d);\n if (pixel > maxValue) {\n maxValue = pixel as number;\n if (flattenPositions) {\n maxPosition = includeBatchInIndex ?\n ((b * convInfo.inHeight + xR) * convInfo.inWidth + xC) *\n convInfo.inChannels +\n d :\n (xR * convInfo.inWidth + xC) * convInfo.inChannels + d;\n } else {\n maxPosition = wR * effectiveFilterWidth + wC;\n }\n }\n }\n }\n maxPositions.set(maxPosition, b, yR, yC, d);\n }\n }\n }\n }\n return maxPositions;\n}\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as seedrandom from 'seedrandom';\n\nimport {ENGINE} from '../../engine';\nimport {env} from '../../environment';\nimport {warn} from '../../log';\nimport * as array_ops_util from '../../ops/array_ops_util';\nimport * as axis_util from '../../ops/axis_util';\nimport * as broadcast_util from '../../ops/broadcast_util';\nimport {complex, imag, real} from '../../ops/complex_ops';\nimport * as concat_util from '../../ops/concat_util';\nimport {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';\nimport {div} from '../../ops/div';\nimport * as erf_util from '../../ops/erf_util';\nimport {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';\nimport * as gather_nd_util from '../../ops/gather_nd_util';\nimport * as ops from '../../ops/ops';\nimport {buffer, scalar, tensor, tensor4d} from '../../ops/ops';\nimport * as scatter_nd_util from '../../ops/scatter_nd_util';\nimport * as selu_util from '../../ops/selu_util';\nimport {computeFlatOffset, computeOutShape, isSliceContinous} from '../../ops/slice_util';\nimport {transpose} from '../../ops/transpose';\nimport {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';\nimport {BackendValues, DataType, DataValues, NumericDataType, Rank, ShapeMap, TypedArray, upcastType} from '../../types';\nimport * as util from '../../util';\nimport {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util';\nimport {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend';\nimport * as backend_util from '../backend_util';\nimport * as complex_util from '../complex_util';\nimport {nonMaxSuppressionV3} from '../non_max_suppression_impl';\nimport {split} from '../split_shared';\nimport {tile} from '../tile_impl';\nimport {topkImpl} from '../topk_impl';\nimport {whereImpl} from '../where_impl';\n\nimport {assertNotComplex} from './cpu_util';\nimport {maxPoolPositions, pool} from './pool_utils';\n\nfunction mapActivation(\n backend: MathBackendCPU, x: Tensor, activation: Activation,\n preluActivationWeights?: Tensor): Tensor {\n if (activation === 'linear') {\n return backend.linear(x);\n } else if (activation === 'relu') {\n return backend.relu(x);\n } else if (activation === 'elu') {\n return backend.elu(x);\n } else if (activation === 'relu6') {\n return backend.relu6(x);\n } else if (activation === 'prelu') {\n return backend.prelu(x, preluActivationWeights);\n }\n throw new Error(\n `Activation ${activation} has not been implemented for the CPU backend.`);\n}\n\nexport interface TensorData {\n values?: BackendValues;\n dtype: D;\n // For complex numbers, the real and imaginary parts are stored as their own\n // individual tensors, with a parent joining the two with the\n // complexTensors field.\n // TODO(smilkov): Replace Tensor with TensorInfo when you modularize ops\n // that work with complex tensors.\n complexTensors?: {real: Tensor, imag: Tensor};\n}\n\nexport class MathBackendCPU extends KernelBackend {\n public blockSize = 48;\n\n data: DataStorage>;\n private firstUse = true;\n\n constructor() {\n super();\n this.data = new DataStorage(this, ENGINE);\n }\n\n write(values: BackendValues, shape: number[], dtype: DataType): DataId {\n if (this.firstUse) {\n this.firstUse = false;\n if (env().get('IS_NODE')) {\n warn(\n '\\n============================\\n' +\n 'Hi there 👋. Looks like you are running TensorFlow.js in ' +\n 'Node.js. To speed things up dramatically, install our node ' +\n 'backend, which binds to TensorFlow C++, by running ' +\n 'npm i @tensorflow/tfjs-node, ' +\n 'or npm i @tensorflow/tfjs-node-gpu if you have CUDA. ' +\n 'Then call require(\\'@tensorflow/tfjs-node\\'); (-gpu ' +\n 'suffix for CUDA) at the start of your program. ' +\n 'Visit https://github.com/tensorflow/tfjs-node for more details.' +\n '\\n============================');\n }\n }\n const dataId = {};\n this.data.set(dataId, {values, dtype});\n return dataId;\n }\n\n move(dataId: DataId, values: BackendValues, shape: number[], dtype: DataType):\n void {\n this.data.set(dataId, {values, dtype});\n }\n\n numDataIds(): number {\n return this.data.numDataIds();\n }\n\n async read(dataId: DataId): Promise {\n return this.readSync(dataId);\n }\n readSync(dataId: DataId): BackendValues {\n const {dtype, complexTensors} = this.data.get(dataId);\n if (dtype === 'complex64') {\n const realValues =\n this.readSync(complexTensors.real.dataId) as Float32Array;\n const imagValues =\n this.readSync(complexTensors.imag.dataId) as Float32Array;\n return complex_util.mergeRealAndImagArrays(realValues, imagValues);\n }\n return this.data.get(dataId).values;\n }\n\n private bufferSync(t: Tensor): TensorBuffer {\n const data = this.readSync(t.dataId);\n let decodedData = data as DataValues;\n if (t.dtype === 'string') {\n try {\n // Decode the bytes into string.\n decodedData = (data as Uint8Array[]).map(d => util.decodeString(d));\n } catch {\n throw new Error('Failed to decode encoded string bytes into utf-8');\n }\n }\n return buffer(t.shape, t.dtype, decodedData) as TensorBuffer;\n }\n\n private makeOutput(\n values: BackendValues, shape: number[], dtype: DataType): T {\n const dataId = this.write(values, shape, dtype);\n return ENGINE.makeTensorFromDataId(dataId, shape, dtype, this) as T;\n }\n\n disposeData(dataId: DataId): void {\n if (this.data.has(dataId)) {\n const {complexTensors} = this.data.get(dataId);\n if (complexTensors != null) {\n complexTensors.real.dispose();\n complexTensors.imag.dispose();\n }\n this.data.delete(dataId);\n }\n }\n\n async time(f: () => void): Promise {\n const start = now();\n f();\n const kernelMs = now() - start;\n return {kernelMs};\n }\n\n memory() {\n return {\n // Unreliable due to automatic gc. The numbers above are cumulative.\n unreliable: true,\n reasons:\n ['The reported memory is an upper bound. Due to automatic garbage ' +\n 'collection, the true allocated memory may be less.']\n };\n }\n\n complex(real: T, imag: T): T {\n const result = this.makeOutput(null, real.shape, 'complex64');\n\n const resultData = this.data.get(result.dataId);\n // The backend owns the reference to the underlying real and imaginary\n // clones. These will explicitly get disposed when the complex tensor is\n // disposed.\n resultData.complexTensors = {\n real: ENGINE.keep(real.clone()),\n imag: ENGINE.keep(imag.clone())\n };\n\n return result as T;\n }\n real(input: T): T {\n const resultData = this.data.get(input.dataId);\n return resultData.complexTensors.real.clone() as T;\n }\n imag(input: T): T {\n const resultData = this.data.get(input.dataId);\n return resultData.complexTensors.imag.clone() as T;\n }\n\n slice(x: T, begin: number[], size: number[]): T {\n assertNotComplex(x, 'slice');\n\n const isContinous = isSliceContinous(x.shape, begin, size);\n if (isContinous) {\n const flatOffset = computeFlatOffset(begin, x.strides);\n const length = util.sizeFromShape(size);\n const vals = this.readSync(x.dataId) as TypedArray;\n return tensor(\n vals.subarray(flatOffset, flatOffset + length), size,\n x.dtype) as T;\n }\n\n const buffer = ops.buffer(size, x.dtype);\n const xBuf = this.bufferSync(x);\n for (let i = 0; i < buffer.size; ++i) {\n const loc = buffer.indexToLoc(i);\n const xLoc = loc.map((idx, j) => idx + begin[j]);\n buffer.values[i] = xBuf.get(...xLoc);\n }\n return buffer.toTensor() as T;\n }\n\n stridedSlice(\n x: T, begin: number[], end: number[], strides: number[]): T {\n assertNotComplex(x, 'stridedSlice');\n\n const outShape = computeOutShape(begin, end, strides);\n\n if (outShape.some(axis => axis === 0)) {\n return ops.tensor([], outShape) as T;\n }\n\n const buffer = ops.buffer(outShape, x.dtype);\n const xBuf = this.bufferSync(x);\n for (let i = 0; i < buffer.size; i++) {\n const loc = buffer.indexToLoc(i);\n\n const newLoc: number[] = new Array(loc.length);\n for (let j = 0; j < newLoc.length; j++) {\n newLoc[j] = loc[j] * strides[j] + begin[j];\n }\n buffer.set(xBuf.get(...newLoc), ...loc);\n }\n\n return buffer.toTensor() as T;\n }\n\n diag(x: Tensor): Tensor {\n const xVals = this.readSync(x.dataId) as TypedArray;\n const buffer = ops.buffer([x.size, x.size], x.dtype);\n const vals = buffer.values;\n for (let i = 0; i < xVals.length; i++) {\n vals[i * x.size + i] = xVals[i];\n }\n return buffer.toTensor();\n }\n\n unstack(x: Tensor, axis: number): Tensor[] {\n const num = x.shape[axis];\n const outShape: number[] = new Array(x.rank - 1);\n let outIndex = 0;\n for (let i = 0; i < x.rank; i++) {\n if (i !== axis) {\n outShape[outIndex++] = x.shape[i];\n }\n }\n\n const begin = new Array(x.rank).fill(0);\n const size = x.shape.slice();\n size[axis] = 1;\n const res = new Array(num);\n for (let i = 0; i < res.length; i++) {\n begin[axis] = i;\n res[i] = this.slice(x, begin, size).reshape(outShape);\n }\n return res;\n }\n\n reverse(x: T, axis: number[]): T {\n assertNotComplex(x, 'reverse');\n\n const buffer = ops.buffer(x.shape, x.dtype);\n const xBuf = this.bufferSync(x);\n\n for (let i = 0; i < buffer.size; i++) {\n const outLoc = buffer.indexToLoc(i);\n const inLoc = outLoc.slice();\n axis.forEach(ax => inLoc[ax] = x.shape[ax] - 1 - inLoc[ax]);\n buffer.set(xBuf.get(...inLoc), ...outLoc);\n }\n\n return buffer.toTensor() as T;\n }\n\n concat(tensors: Tensor[], axis: number): Tensor {\n if (tensors[0].dtype === 'complex64') {\n const reals = tensors.map((t) => real(t));\n const imags = tensors.map((t) => imag(t));\n return complex(this.concat(reals, axis), this.concat(imags, axis));\n }\n const tensors2D = tensors.map(t => {\n const innerSize = util.sizeFromShape(t.shape.slice(axis));\n return t.as2D(-1, innerSize);\n });\n const outShape =\n concat_util.computeOutShape(tensors2D.map(t => t.shape), 1 /* axis */);\n const values =\n ops.buffer(outShape as [number, number], tensors[0].dtype as 'float32')\n .values;\n if (tensors2D[0].shape[0] === 1) {\n // Use built-in TypedArray.set() method for speed.\n let offset = 0;\n tensors2D.forEach(t => {\n values.set(this.readSync(t.dataId) as TypedArray, offset);\n offset += t.size;\n });\n } else {\n let colOffset = 0;\n tensors2D.forEach(t => {\n const tVals = this.readSync(t.dataId) as TypedArray;\n let tIdx = 0;\n for (let row = 0; row < t.shape[0]; ++row) {\n const resIdx = row * outShape[1] + colOffset;\n for (let col = 0; col < t.shape[1]; ++col) {\n values[resIdx + col] = tVals[tIdx++];\n }\n }\n colOffset += t.shape[1];\n });\n }\n const finalOutShape =\n concat_util.computeOutShape(tensors.map(t => t.shape), axis);\n return tensor(values, finalOutShape, tensors[0].dtype);\n }\n\n neg(x: T): T {\n assertNotComplex(x, 'neg');\n\n return this.multiply(ops.scalar(-1), x) as T;\n }\n\n add(a: Tensor, b: Tensor): Tensor {\n if (a.dtype === 'complex64' || b.dtype === 'complex64') {\n return this.broadcastedBinaryComplexOp(\n a.cast('complex64'), b.cast('complex64'),\n (aReal, aImag, bReal, bImag) => {\n return {real: aReal + bReal, imag: aImag + bImag};\n });\n }\n\n return this.broadcastedBinaryOp(\n a, b, upcastType(a.dtype, b.dtype),\n (aValue, bValue) => aValue + bValue);\n }\n\n addN(tensors: T[]): T {\n assertNotComplex(tensors, 'addN');\n\n const vals = tensors.map(t => this.readSync(t.dataId) as TypedArray);\n const result = ops.buffer(tensors[0].shape, tensors[0].dtype as 'float32');\n const resultVals = result.values;\n for (let i = 0; i < tensors.length; i++) {\n const currVals = vals[i];\n for (let j = 0; j < resultVals.length; j++) {\n resultVals[j] += currVals[j];\n }\n }\n return result.toTensor() as T;\n }\n\n softmax(logits: T, dim: number): T {\n const axes = util.parseAxisParam([dim], logits.shape);\n const maxLogit = this.max(logits, axes);\n const expandedShape = axis_util.expandShapeToKeepDim(maxLogit.shape, axes);\n const a = this.subtract(logits, maxLogit.reshape(expandedShape));\n const b = this.exp(a);\n const sumExp = this.sum(b, axes).reshape(expandedShape);\n\n // TODO(annxingyuan): Call divImpl rather than op as part of softmax kernel\n // modularization.\n return div(b, sumExp);\n }\n\n subtract(a: Tensor, b: Tensor): Tensor {\n if (a.dtype === 'complex64' || b.dtype === 'complex64') {\n return this.broadcastedBinaryComplexOp(\n a.cast('complex64'), b.cast('complex64'),\n (aReal, aImag, bReal, bImag) => {\n return {real: aReal - bReal, imag: aImag - bImag};\n });\n }\n\n return this.broadcastedBinaryOp(\n a, b, upcastType(a.dtype, b.dtype),\n (aValue, bValue) => aValue - bValue);\n }\n\n pow(a: T, b: Tensor): T {\n assertNotComplex([a, b], 'pow');\n\n return this.broadcastedBinaryOp(\n a, b, a.dtype, (aValue, bValue) => Math.pow(aValue, bValue)) as\n T;\n }\n\n batchMatMul(\n a: Tensor3D, b: Tensor3D, transposeA: boolean,\n transposeB: boolean): Tensor3D {\n assertNotComplex([a, b], 'matMul');\n\n const sharedDim = transposeA ? a.shape[1] : a.shape[2];\n const leftDim = transposeA ? a.shape[2] : a.shape[1];\n const rightDim = transposeB ? b.shape[1] : b.shape[2];\n const batchDim = a.shape[0];\n\n const aValues = this.readSync(a.dataId) as TypedArray;\n const bValues = this.readSync(b.dataId) as TypedArray;\n const [aBatch, aOuterStep, aInnerStep] = transposeA ?\n [a.strides[0], 1, a.strides[1]] :\n [a.strides[0], a.strides[1], 1];\n const [bInnerStep, bOuterStep, bBatch] = transposeB ?\n [1, b.strides[1], b.strides[0]] :\n [b.strides[1], 1, b.strides[0]];\n\n const size = leftDim * rightDim;\n const result = buffer([batchDim, leftDim, rightDim], a.dtype);\n const resVals = result.values as TypedArray;\n const blockSize = this.blockSize;\n\n for (let b = 0; b < batchDim; b++) {\n for (let i0 = 0; i0 < leftDim; i0 += blockSize) {\n for (let j0 = 0; j0 < rightDim; j0 += blockSize) {\n for (let k0 = 0; k0 < sharedDim; k0 += blockSize) {\n // for when blockSize doesn't evenly divide the input\n const iBlock = Math.min(i0 + blockSize, leftDim);\n const jBlock = Math.min(j0 + blockSize, rightDim);\n const kBlock = Math.min(k0 + blockSize, sharedDim);\n\n for (let i = i0; i < iBlock; i++) {\n for (let j = j0; j < jBlock; j++) {\n let sum = 0.0;\n\n for (let k = k0; k < kBlock; k++) {\n sum += aValues[b * aBatch + i * aOuterStep + k * aInnerStep] *\n bValues[k * bInnerStep + j * bOuterStep + b * bBatch];\n }\n resVals[b * size + (i * rightDim + j)] += sum;\n }\n }\n }\n }\n }\n }\n return result.toTensor() as Tensor3D;\n }\n\n fusedBatchMatMul(\n {a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:\n FusedBatchMatMulConfig): Tensor3D {\n let result = this.batchMatMul(a, b, transposeA, transposeB);\n if (bias) {\n result = this.add(result, bias) as Tensor3D;\n }\n if (activation) {\n result =\n mapActivation(this, result, activation, preluActivationWeights) as\n Tensor3D;\n }\n return result;\n }\n\n multiply(a: Tensor, b: Tensor): Tensor {\n if (a.dtype === 'complex64' || b.dtype === 'complex64') {\n return this.broadcastedBinaryComplexOp(\n a.cast('complex64'), b.cast('complex64'),\n (aReal, aImag, bReal, bImag) => {\n return {\n real: aReal * bReal - aImag * bImag,\n imag: aReal * bImag + aImag * bReal\n };\n });\n }\n\n return this.broadcastedBinaryOp(\n a, b, upcastType(a.dtype, b.dtype),\n (aValue, bValue) => aValue * bValue);\n }\n\n floorDiv(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'floorDiv');\n\n const op = (a: number, b: number) => Math.floor(a / b);\n const outputDtype = 'int32';\n return this.broadcastedBinaryOp(a, b, outputDtype, op);\n }\n\n sum(x: Tensor, axes: number[]): Tensor {\n assertNotComplex(x, 'sum');\n\n axis_util.assertAxesAreInnerMostDims('sum', axes, x.rank);\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const resultDtype = upcastType(x.dtype, 'int32');\n const result = ops.zeros(outShape, resultDtype);\n const reduceSize = util.sizeFromShape(reduceShape);\n const vals = this.readSync(result.dataId) as TypedArray;\n\n const aVals = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < vals.length; ++i) {\n const offset = i * reduceSize;\n let sum = 0;\n for (let j = 0; j < reduceSize; ++j) {\n sum += aVals[offset + j];\n }\n vals[i] = sum;\n }\n return result;\n }\n\n prod(x: Tensor, axes: number[]): Tensor {\n assertNotComplex(x, 'sum');\n\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const resultDtype = upcastType(x.dtype, 'int32');\n const result = ops.zeros(outShape, resultDtype);\n const reduceSize = util.sizeFromShape(reduceShape);\n const vals = this.readSync(result.dataId) as TypedArray;\n\n const aVals = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < vals.length; ++i) {\n const offset = i * reduceSize;\n let prod = 1;\n for (let j = 0; j < reduceSize; ++j) {\n prod *= aVals[offset + j];\n }\n vals[i] = prod;\n }\n return result;\n }\n\n unsortedSegmentSum(\n x: T, segmentIds: Tensor1D, numSegments: number): Tensor {\n assertNotComplex(x, 'unsortedSegmentSum');\n\n const res = [];\n\n // Reshape the segment id's so that they can be broadcast with\n // x. The new shape should be [segmentIds.shape, 1, ..., 1]\n const numIters = x.rank - segmentIds.rank;\n for (let i = 0; i < numIters; ++i) {\n segmentIds = segmentIds.expandDims(i + 1);\n }\n\n for (let i = 0; i < numSegments; ++i) {\n const segmentId = ops.scalar(i, 'int32');\n const mask = ops.equal(segmentId, segmentIds).asType('float32');\n const sum = mask.mul(x).sum(0);\n res.push(sum);\n }\n\n return ops.stack(res);\n }\n\n argMin(x: Tensor, axis: number): Tensor {\n assertNotComplex(x, 'argMin');\n\n const axes = [axis];\n axis_util.assertAxesAreInnerMostDims('argMin', axes, x.rank);\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const result = ops.zeros(outShape, 'int32');\n const reduceSize = util.sizeFromShape(reduceShape);\n const vals = this.readSync(result.dataId) as TypedArray;\n\n const aVals = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < vals.length; ++i) {\n const offset = i * reduceSize;\n let min = aVals[offset];\n let minIndex = 0;\n for (let j = 0; j < reduceSize; ++j) {\n const value = aVals[offset + j];\n if (value < min) {\n min = value;\n minIndex = j;\n }\n }\n vals[i] = minIndex;\n }\n return result;\n }\n\n argMax(x: Tensor, axis: number): Tensor {\n assertNotComplex(x, 'argMax');\n\n const axes = [axis];\n axis_util.assertAxesAreInnerMostDims('argMax', axes, x.rank);\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const result = ops.zeros(outShape, 'int32');\n const reduceSize = util.sizeFromShape(reduceShape);\n const vals = this.readSync(result.dataId) as TypedArray;\n\n const aVals = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < vals.length; ++i) {\n const offset = i * reduceSize;\n let max = aVals[offset];\n let maxIndex = 0;\n for (let j = 0; j < reduceSize; ++j) {\n const value = aVals[offset + j];\n if (value > max) {\n max = value;\n maxIndex = j;\n }\n }\n vals[i] = maxIndex;\n }\n return result;\n }\n\n cumsum(x: Tensor, axis: number, exclusive: boolean, reverse: boolean):\n Tensor {\n assertNotComplex(x, 'cumsum');\n\n if (axis !== x.rank - 1) {\n throw new Error(\n `backend.cumsum in CPU expects an inner-most axis=${x.rank - 1} ` +\n `but got axis=${axis}`);\n }\n const resultDtype = upcastType(x.dtype, 'int32');\n const result = ops.zeros(x.shape, resultDtype);\n const vals = this.readSync(result.dataId) as TypedArray;\n\n const aVals = this.readSync(x.dataId) as TypedArray;\n const finalDim = x.shape[x.rank - 1];\n const indexAdjuster = reverse ?\n (i: number, j: number) => i + finalDim - j - 1 :\n (i: number, j: number) => i + j;\n for (let i = 0; i < aVals.length; i += finalDim) {\n for (let j = 0; j < finalDim; j++) {\n const idx = indexAdjuster(i, j);\n if (j === 0) {\n vals[idx] = exclusive ? 0 : aVals[idx];\n } else {\n const prevIdx = indexAdjuster(i, j - 1);\n vals[idx] = exclusive ? aVals[prevIdx] + vals[prevIdx] :\n aVals[idx] + vals[prevIdx];\n }\n }\n }\n return result;\n }\n\n equal(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'equal');\n\n return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {\n return (aVal === bVal) ? 1 : 0;\n });\n }\n\n notEqual(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'notEqual');\n\n return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {\n return (aVal !== bVal) ? 1 : 0;\n });\n }\n\n less(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'less');\n\n return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {\n return (aVal < bVal) ? 1 : 0;\n });\n }\n\n lessEqual(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'lessEqual');\n\n return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {\n return (aVal <= bVal) ? 1 : 0;\n });\n }\n\n greater(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'greater');\n\n return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {\n return (aVal > bVal) ? 1 : 0;\n });\n }\n\n greaterEqual(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'greaterEqual');\n\n return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {\n return (aVal >= bVal) ? 1 : 0;\n });\n }\n\n logicalNot(x: T): T {\n assertNotComplex(x, 'logicalNot');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Uint8Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n newValues[i] = values[i] ? 0 : 1;\n }\n return this.makeOutput(newValues, x.shape, 'bool');\n }\n\n logicalAnd(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'logicalAnd');\n\n return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {\n return aVal && bVal;\n });\n }\n\n logicalOr(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'logicalOr');\n\n return this.broadcastedBinaryOp(a, b, 'bool', (aVal, bVal) => {\n return aVal || bVal;\n });\n }\n\n select(condition: Tensor, a: Tensor, b: Tensor): Tensor {\n assertNotComplex([condition, a, b], 'select');\n\n const values = this.readSync(condition.dataId) as TypedArray;\n const aValues = this.readSync(a.dataId) as TypedArray;\n const bValues = this.readSync(b.dataId) as TypedArray;\n const result = ops.zeros(a.shape, upcastType(a.dtype, b.dtype));\n const newValues = this.readSync(result.dataId) as TypedArray;\n let index = 0;\n const offset = condition.rank === 0 || condition.rank > 1 || a.rank === 1 ?\n 1 :\n util.sizeFromShape(a.shape.slice(1));\n\n for (let i = 0; i < values.length; i++) {\n for (let j = 0; j < offset; j++) {\n if (values[i] === 1) {\n newValues[index++] = aValues[i];\n } else {\n newValues[index++] = bValues[i];\n }\n }\n }\n\n return result;\n }\n\n where(condition: Tensor): Tensor2D {\n assertNotComplex([condition], 'where');\n\n const condVals = this.readSync(condition.dataId) as TypedArray;\n return whereImpl(condition.shape, condVals);\n }\n\n topk(x: T, k: number, sorted: boolean): [T, T] {\n assertNotComplex(x, 'topk');\n\n const xVals = this.readSync(x.dataId) as TypedArray;\n return topkImpl(xVals, x.shape, x.dtype as NumericDataType, k, sorted);\n }\n\n min(x: Tensor, axes: number[]): Tensor {\n assertNotComplex(x, 'min');\n\n axis_util.assertAxesAreInnerMostDims('min', axes, x.rank);\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const result = ops.zeros(outShape, x.dtype);\n const reduceSize = util.sizeFromShape(reduceShape);\n const vals = this.readSync(result.dataId) as TypedArray;\n\n const aVals = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < vals.length; ++i) {\n const offset = i * reduceSize;\n let min = aVals[offset];\n for (let j = 0; j < reduceSize; ++j) {\n const value = aVals[offset + j];\n if (value < min) {\n min = value;\n }\n }\n vals[i] = min;\n }\n return result;\n }\n\n minimum(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'minimum');\n\n return this.broadcastedBinaryOp(\n a, b, a.dtype, (aVal, bVal) => Math.min(aVal, bVal));\n }\n\n mod(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'mod');\n\n return this.broadcastedBinaryOp(a, b, a.dtype, (aVal, bVal) => {\n const rem = aVal % bVal;\n if ((aVal < 0 && bVal < 0) || (aVal >= 0 && bVal >= 0)) {\n return rem;\n } else {\n return (rem + bVal) % bVal;\n }\n });\n }\n\n max(x: Tensor, axes: number[]): Tensor {\n assertNotComplex(x, 'max');\n\n axis_util.assertAxesAreInnerMostDims('max', axes, x.rank);\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const result = ops.zeros(outShape, x.dtype);\n const reduceSize = util.sizeFromShape(reduceShape);\n const vals = this.readSync(result.dataId) as TypedArray;\n\n const aVals = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < vals.length; ++i) {\n const offset = i * reduceSize;\n let max = aVals[offset];\n for (let j = 0; j < reduceSize; ++j) {\n const value = aVals[offset + j];\n if (value > max) {\n max = value;\n }\n }\n vals[i] = max;\n }\n return result;\n }\n\n maximum(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'maximum');\n\n return this.broadcastedBinaryOp(\n a, b, a.dtype, (aVal, bVal) => Math.max(aVal, bVal));\n }\n\n all(x: Tensor, axes: number[]): Tensor {\n assertNotComplex(x, 'all');\n\n axis_util.assertAxesAreInnerMostDims('all', axes, x.rank);\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const result = ops.zeros(outShape, x.dtype);\n const reduceSize = util.sizeFromShape(reduceShape);\n const vals = this.readSync(result.dataId) as TypedArray;\n\n const aVals = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < vals.length; ++i) {\n const offset = i * reduceSize;\n let all = aVals[offset];\n for (let j = 0; j < reduceSize; ++j) {\n const value = aVals[offset + j];\n all = all && value;\n }\n vals[i] = all;\n }\n return result;\n }\n\n any(x: Tensor, axes: number[]): Tensor {\n assertNotComplex(x, 'any');\n\n axis_util.assertAxesAreInnerMostDims('any', axes, x.rank);\n const [outShape, reduceShape] =\n axis_util.computeOutAndReduceShapes(x.shape, axes);\n const result = ops.zeros(outShape, x.dtype);\n const reduceSize = util.sizeFromShape(reduceShape);\n const vals = this.readSync(result.dataId) as TypedArray;\n\n const aVals = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < vals.length; ++i) {\n const offset = i * reduceSize;\n let anyVal = aVals[offset];\n for (let j = 0; j < reduceSize; ++j) {\n const value = aVals[offset + j];\n anyVal = anyVal || value;\n }\n vals[i] = anyVal;\n }\n return result;\n }\n\n squaredDifference(a: Tensor, b: Tensor): Tensor {\n assertNotComplex([a, b], 'squaredDifference');\n\n return this.broadcastedBinaryOp(a, b, a.dtype, (aVal, bVal) => {\n const diff = aVal - bVal;\n return diff * diff;\n });\n }\n\n ceil(x: T): T {\n assertNotComplex(x, 'ceil');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Float32Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n newValues[i] = Math.ceil(values[i]);\n }\n return this.makeOutput(newValues, x.shape, 'float32');\n }\n\n floor(x: T): T {\n assertNotComplex(x, 'floor');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Float32Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n newValues[i] = Math.floor(values[i]);\n }\n return this.makeOutput(newValues, x.shape, 'float32');\n }\n\n sign(x: T): T {\n assertNotComplex(x, 'x');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Float32Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n if (values[i] < 0) {\n newValues[i] = -1;\n } else if (values[i] > 0) {\n newValues[i] = 1;\n } else {\n newValues[i] = 0;\n }\n }\n return this.makeOutput(newValues, x.shape, 'float32');\n }\n\n isNaN(x: T): T {\n assertNotComplex(x, 'x');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Uint8Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n if (Number.isNaN(values[i])) {\n newValues[i] = 1;\n }\n }\n return this.makeOutput(newValues, x.shape, 'bool');\n }\n\n isInf(x: T): T {\n assertNotComplex(x, 'x');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Uint8Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n if (Math.abs(values[i]) === Infinity) {\n newValues[i] = 1;\n }\n }\n return this.makeOutput(newValues, x.shape, 'bool');\n }\n\n isFinite(x: T): T {\n assertNotComplex(x, 'x');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Uint8Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n if (Number.isFinite(values[i])) {\n newValues[i] = 1;\n }\n }\n return this.makeOutput(newValues, x.shape, 'bool');\n }\n\n round(x: T): T {\n assertNotComplex(x, 'round');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Float32Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n // The algorithm is based on banker's rounding.\n const base = Math.floor(values[i]);\n if (values[i] - base < 0.5) {\n newValues[i] = Math.floor(values[i]);\n } else if (values[i] - base > 0.5) {\n newValues[i] = Math.ceil(values[i]);\n } else {\n if (base % 2.0 === 0.0) {\n newValues[i] = base;\n } else {\n newValues[i] = base + 1.0;\n }\n }\n }\n return this.makeOutput(newValues, x.shape, 'float32');\n }\n\n exp(x: T): T {\n assertNotComplex(x, 'exp');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Float32Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n newValues[i] = Math.exp(values[i]);\n }\n return this.makeOutput(newValues, x.shape, 'float32');\n }\n\n expm1(x: T): T {\n assertNotComplex(x, 'expm1');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Float32Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n newValues[i] = Math.expm1(values[i]);\n }\n return this.makeOutput(newValues, x.shape, 'float32');\n }\n\n log(x: T): T {\n assertNotComplex(x, 'log');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Float32Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n const value = values[i];\n newValues[i] = Math.log(value);\n }\n return this.makeOutput(newValues, x.shape, 'float32');\n }\n\n log1p(x: T): T {\n assertNotComplex(x, 'log1p');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Float32Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n const value = values[i];\n newValues[i] = Math.log1p(value);\n }\n return this.makeOutput(newValues, x.shape, 'float32');\n }\n\n sqrt(x: T): T {\n assertNotComplex(x, 'sqrt');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Float32Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n const value = values[i];\n newValues[i] = Math.sqrt(value);\n }\n return this.makeOutput(newValues, x.shape, 'float32');\n }\n\n rsqrt(x: T): T {\n assertNotComplex(x, 'rsqrt');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Float32Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n const value = values[i];\n newValues[i] = 1 / Math.sqrt(value);\n }\n return this.makeOutput(newValues, x.shape, 'float32');\n }\n\n reciprocal(x: T): T {\n assertNotComplex(x, 'reciprocal');\n\n const values = this.readSync(x.dataId) as TypedArray;\n const newValues = new Float32Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n newValues[i] = 1 / values[i];\n }\n return this.makeOutput(newValues, x.shape, 'float32');\n }\n\n linear(x: T): T {\n return x;\n }\n\n relu(x: T): T {\n assertNotComplex(x, 'relu');\n\n const res = ops.zeros(x.shape, x.dtype);\n const resVals = this.readSync(res.dataId) as TypedArray;\n const inVals = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < inVals.length; ++i) {\n resVals[i] = Math.max(0, inVals[i]);\n }\n return res as T;\n }\n\n relu6(x: T): T {\n assertNotComplex(x, 'relu');\n\n const res = ops.zeros(x.shape, x.dtype);\n const resVals = this.readSync(res.dataId) as TypedArray;\n const inVals = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < inVals.length; ++i) {\n resVals[i] = Math.min(Math.max(0, inVals[i]), 6);\n }\n return res as T;\n }\n\n prelu(x: T, a: T): T {\n assertNotComplex([x, a], 'prelu');\n\n return this.broadcastedBinaryOp(\n x, a, x.dtype,\n (xValue, aValue) => xValue < 0 ? aValue * xValue : xValue) as T;\n }\n\n elu(x: T): T {\n assertNotComplex(x, 'elu');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n const v = values[i];\n if (v >= 0) {\n resultValues[i] = v;\n } else {\n resultValues[i] = (Math.exp(v) - 1);\n }\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n eluDer(dy: T, y: T): T {\n assertNotComplex([dy, y], 'eluDer');\n\n const resultValues = new Float32Array(y.size);\n const values = this.readSync(y.dataId) as TypedArray;\n const dyValues = this.readSync(dy.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n const v = values[i];\n if (v >= 1) {\n resultValues[i] = dyValues[i];\n } else {\n resultValues[i] = dyValues[i] * (v + 1);\n }\n }\n return this.makeOutput(resultValues, y.shape, 'float32');\n }\n\n selu(x: T): T {\n assertNotComplex(x, 'selu');\n\n // Stable and Attracting Fixed Point (0, 1) for Normalized Weights.\n // see: https://arxiv.org/abs/1706.02515\n const scaleAlpha = selu_util.SELU_SCALEALPHA;\n const scale = selu_util.SELU_SCALE;\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n const v = values[i];\n if (v >= 0) {\n resultValues[i] = scale * v;\n } else {\n resultValues[i] = scaleAlpha * (Math.exp(v) - 1);\n }\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n clip(x: T, min: number, max: number): T {\n assertNotComplex(x, 'clip');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n const v = values[i];\n resultValues[i] = v > max ? max : (v < min ? min : v);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n abs(x: T): T {\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = Math.abs(values[i]);\n }\n\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n complexAbs(x: T): T {\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n\n for (let i = 0; i < x.size; ++i) {\n const real = values[i * 2];\n const imag = values[i * 2 + 1];\n resultValues[i] = Math.hypot(real, imag);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n int(x: T): T {\n assertNotComplex(x, 'int');\n\n const resultValues = new Int32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = values[i];\n }\n return this.makeOutput(resultValues, x.shape, 'int32');\n }\n\n sigmoid(x: T): T {\n assertNotComplex(x, 'sigmoid');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = 1 / (1 + Math.exp(-values[i]));\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n softplus(x: T): T {\n assertNotComplex(x, 'softplus');\n\n // mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX\n\n // epsilon is the difference between 1.0 and the next representable float.\n // For a single precision 32 bit float this should be 2^-23, see:\n // https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm\n const epsilon = 1.1920928955078125e-7;\n const threshold = Math.log(epsilon) + 2.0;\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n\n for (let i = 0; i < values.length; ++i) {\n // Value above which exp(x) may overflow, but softplus(x) == x\n // is within machine epsilon.\n const tooLarge = values[i] > -threshold;\n\n // Value below which exp(x) may underflow, but softplus(x) == exp(x)\n // is within machine epsilon.\n const tooSmall = values[i] < threshold;\n\n const expX = Math.exp(values[i]);\n let result;\n\n if (tooSmall) {\n result = expX;\n } else if (tooLarge) {\n result = values[i];\n } else {\n result = Math.log(1.0 + expX);\n }\n resultValues[i] = result;\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n sin(x: T): T {\n assertNotComplex(x, 'sin');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = Math.sin(values[i]);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n cos(x: T): T {\n assertNotComplex(x, 'cos');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = Math.cos(values[i]);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n tan(x: T): T {\n assertNotComplex(x, 'tan');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = Math.tan(values[i]);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n asin(x: T): T {\n assertNotComplex(x, 'asin');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = Math.asin(values[i]);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n acos(x: T): T {\n assertNotComplex(x, 'acos');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = Math.acos(values[i]);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n atan(x: T): T {\n assertNotComplex(x, 'atan');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = Math.atan(values[i]);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n atan2(a: T, b: T): T {\n assertNotComplex([a, b], 'atan2');\n\n return this.broadcastedBinaryOp(\n a, b, a.dtype, (aValue, bValue) => Math.atan2(aValue, bValue)) as\n T;\n }\n\n sinh(x: T): T {\n assertNotComplex(x, 'sinh');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = Math.sinh(values[i]);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n cosh(x: T): T {\n assertNotComplex(x, 'cosh');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = Math.cosh(values[i]);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n tanh(x: T): T {\n assertNotComplex(x, 'tanh');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = util.tanh(values[i]);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n asinh(x: T): T {\n assertNotComplex(x, 'asinh');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = Math.asinh(values[i]);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n acosh(x: T): T {\n assertNotComplex(x, 'acosh');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = Math.acosh(values[i]);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n atanh(x: T): T {\n assertNotComplex(x, 'atanh');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n resultValues[i] = Math.atanh(values[i]);\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n erf(x: T): T {\n assertNotComplex(x, 'erf');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n const p = erf_util.ERF_P;\n const a1 = erf_util.ERF_A1;\n const a2 = erf_util.ERF_A2;\n const a3 = erf_util.ERF_A3;\n const a4 = erf_util.ERF_A4;\n const a5 = erf_util.ERF_A5;\n for (let i = 0; i < values.length; ++i) {\n const sign = Math.sign(values[i]);\n const v = Math.abs(values[i]);\n const t = 1.0 / (1.0 + p * v);\n resultValues[i] = sign *\n (1.0 -\n (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t *\n Math.exp(-v * v));\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n step(x: T, alpha = 0): T {\n assertNotComplex(x, 'step');\n\n const resultValues = new Float32Array(x.size);\n const values = this.readSync(x.dataId) as TypedArray;\n for (let i = 0; i < values.length; ++i) {\n const value = values[i];\n if (isNaN(value)) {\n resultValues[i] = NaN;\n } else {\n resultValues[i] = value > 0 ? 1 : alpha;\n }\n }\n return this.makeOutput(resultValues, x.shape, 'float32');\n }\n\n fusedConv2d(\n {input, filter, convInfo, bias, activation, preluActivationWeights}:\n FusedConv2DConfig): Tensor4D {\n let result = this.conv2d(input, filter, convInfo);\n\n if (bias) {\n result = this.add(result, bias) as Tensor4D;\n }\n if (activation) {\n result =\n mapActivation(this, result, activation, preluActivationWeights) as\n Tensor4D;\n }\n return result;\n }\n\n conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n assertNotComplex([x, filter], 'conv2d');\n\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const padLeft = convInfo.padInfo.left;\n const padTop = convInfo.padInfo.top;\n const isChannelsLast = convInfo.dataFormat === 'channelsLast';\n\n const y = ops.buffer(convInfo.outShape, x.dtype as 'float32');\n\n const xBatchStride = x.strides[0];\n const xRowStride = isChannelsLast ? x.strides[1] : x.strides[2];\n const xColStride = isChannelsLast ? x.strides[2] : 1;\n const xChannelStride = isChannelsLast ? 1 : x.strides[1];\n const yBatchStride = y.strides[0];\n const yRowStride = isChannelsLast ? y.strides[1] : y.strides[2];\n const yColStride = isChannelsLast ? y.strides[2] : 1;\n const yChannelStride = isChannelsLast ? 1 : y.strides[1];\n\n const xVals = this.readSync(x.dataId) as TypedArray;\n const wVals = this.readSync(filter.dataId) as TypedArray;\n const yVals = y.values;\n\n for (let b = 0; b < convInfo.batchSize; ++b) {\n const xOffset1 = b * xBatchStride;\n const yOffset1 = b * yBatchStride;\n for (let yR = 0; yR < convInfo.outHeight; ++yR) {\n const yOffset2 = yOffset1 + yR * yRowStride;\n const xRCorner = yR * convInfo.strideHeight - padTop;\n for (let wR = 0; wR < filterHeight; wR++) {\n const xR = xRCorner + wR * dilationHeight;\n if (xR < 0 || xR >= convInfo.inHeight) {\n continue;\n }\n const wOffset1 = wR * filter.strides[0];\n const xOffset2 = xOffset1 + xR * xRowStride;\n for (let yC = 0; yC < convInfo.outWidth; ++yC) {\n const yOffset3 = yOffset2 + yC * yColStride;\n const xCCorner = yC * convInfo.strideWidth - padLeft;\n for (let wC = 0; wC < filterWidth; wC++) {\n const xC = xCCorner + wC * dilationWidth;\n if (xC < 0 || xC >= convInfo.inWidth) {\n continue;\n }\n const wOffset2 = wOffset1 + wC * filter.strides[1];\n const xOffset3 = xOffset2 + xC * xColStride;\n let wOffset3 = wOffset2;\n for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {\n const xVal = xVals[xOffset3 + d1 * xChannelStride];\n for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {\n yVals[yOffset3 + d2 * yChannelStride] +=\n xVal * wVals[wOffset3 + d2];\n }\n wOffset3 += convInfo.outChannels;\n }\n }\n }\n }\n }\n }\n return y.toTensor() as Tensor4D;\n }\n\n conv3d(x: Tensor5D, filter: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n const filterDepth = convInfo.filterDepth;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const dilationDepth = convInfo.dilationDepth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const padFront = convInfo.padInfo.front;\n const padLeft = convInfo.padInfo.left;\n const padTop = convInfo.padInfo.top;\n const y = ops.buffer(convInfo.outShape, x.dtype as 'float32');\n\n const xVals = this.readSync(x.dataId) as TypedArray;\n const wVals = this.readSync(filter.dataId) as TypedArray;\n const yVals = y.values;\n\n for (let b = 0; b < convInfo.batchSize; ++b) {\n const xOffset1 = b * x.strides[0];\n const yOffset1 = b * y.strides[0];\n for (let yF = 0; yF < convInfo.outDepth; ++yF) {\n const yOffset2 = yOffset1 + yF * y.strides[1];\n const xFCorner = yF * convInfo.strideDepth - padFront;\n for (let wF = 0; wF < filterDepth; wF++) {\n const xF = xFCorner + wF * dilationDepth;\n if (xF < 0 || xF >= convInfo.inDepth) {\n continue;\n }\n const wOffset1 = wF * filter.strides[0];\n const xOffset2 = xOffset1 + xF * x.strides[1];\n\n for (let yR = 0; yR < convInfo.outHeight; ++yR) {\n const yOffset3 = yOffset2 + yR * y.strides[2];\n const xRCorner = yR * convInfo.strideHeight - padTop;\n for (let wR = 0; wR < filterHeight; wR++) {\n const xR = xRCorner + wR * dilationHeight;\n if (xR < 0 || xR >= convInfo.inHeight) {\n continue;\n }\n const wOffset2 = wOffset1 + wR * filter.strides[1];\n const xOffset3 = xOffset2 + xR * x.strides[2];\n for (let yC = 0; yC < convInfo.outWidth; ++yC) {\n const yOffset4 = yOffset3 + yC * convInfo.outChannels;\n const xCCorner = yC * convInfo.strideWidth - padLeft;\n for (let wC = 0; wC < filterWidth; wC++) {\n const xC = xCCorner + wC * dilationWidth;\n if (xC < 0 || xC >= convInfo.inWidth) {\n continue;\n }\n const wOffset3 = wOffset2 + wC * filter.strides[2];\n const xOffset4 = xOffset3 + xC * convInfo.inChannels;\n let wOffset4 = wOffset3;\n for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {\n const xVal = xVals[xOffset4 + d1];\n for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {\n yVals[yOffset4 + d2] += xVal * wVals[wOffset4 + d2];\n }\n wOffset4 += convInfo.outChannels;\n }\n }\n }\n }\n }\n }\n }\n }\n return y.toTensor();\n }\n\n conv2dDerInput(dy: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n assertNotComplex([dy, filter], 'conv2dDerInput');\n\n const dx = ops.buffer(convInfo.inShape, 'float32');\n const dxValues = dx.values;\n const dyValues = this.readSync(dy.dataId) as TypedArray;\n const fltValues = this.readSync(filter.dataId) as TypedArray;\n const [fltS0, fltS1, fltS2] = filter.strides;\n const {\n batchSize,\n filterHeight,\n filterWidth,\n inChannels,\n inHeight,\n inWidth,\n outChannels,\n outHeight,\n outWidth,\n strideHeight,\n strideWidth,\n dataFormat\n } = convInfo;\n const topPad = filterHeight - 1 - convInfo.padInfo.top;\n const leftPad = filterWidth - 1 - convInfo.padInfo.left;\n\n const isChannelsLast = dataFormat === 'channelsLast';\n const xBatchStride = dx.strides[0];\n const xRowStride = isChannelsLast ? dx.strides[1] : dx.strides[2];\n const xColStride = isChannelsLast ? dx.strides[2] : 1;\n const xChannelStride = isChannelsLast ? 1 : dx.strides[1];\n const yBatchStride = dy.strides[0];\n const yRowStride = isChannelsLast ? dy.strides[1] : dy.strides[2];\n const yColStride = isChannelsLast ? dy.strides[2] : 1;\n const yChannelStride = isChannelsLast ? 1 : dy.strides[1];\n\n for (let b = 0; b < batchSize; ++b) {\n for (let d1 = 0; d1 < inChannels; ++d1) {\n for (let xR = 0; xR < inHeight; ++xR) {\n const xRCorner = xR - topPad;\n const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));\n const yRMax =\n Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);\n\n for (let xC = 0; xC < inWidth; ++xC) {\n const xCCorner = xC - leftPad;\n const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));\n const yCMax =\n Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);\n\n let dotProd = 0;\n for (let yR = xRMin; yR < yRMax; ++yR) {\n const wR = yR * strideHeight - xRCorner;\n\n for (let yC = xCMin; yC < yCMax; ++yC) {\n const wC = yC * strideWidth - xCCorner;\n const dyOffset =\n yBatchStride * b + yRowStride * yR + yColStride * yC;\n const fltOffset = fltS0 * (filterHeight - 1 - wR) +\n fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;\n\n for (let d2 = 0; d2 < outChannels; ++d2) {\n const pixel = dyValues[dyOffset + yChannelStride * d2];\n const weight = fltValues[fltOffset + d2];\n dotProd += pixel * weight;\n }\n }\n }\n const dxOffset = xBatchStride * b + xRowStride * xR +\n xColStride * xC + xChannelStride * d1;\n dxValues[dxOffset] = dotProd;\n }\n }\n }\n }\n return dx.toTensor();\n }\n\n conv3dDerInput(dy: Tensor5D, filter: Tensor5D, convInfo: Conv3DInfo):\n Tensor5D {\n const dx = ops.buffer(convInfo.inShape, 'float32');\n const dxValues = dx.values;\n const [dxS0, dxS1, dxS2, dxS3] = dx.strides;\n const dyValues = this.readSync(dy.dataId) as TypedArray;\n const [dyS0, dyS1, dyS2, dyS3] = dy.strides;\n const fltValues = this.readSync(filter.dataId) as TypedArray;\n const [fltS0, fltS1, fltS2, fltS3] = filter.strides;\n const {\n batchSize,\n filterDepth,\n filterHeight,\n filterWidth,\n inChannels,\n inDepth,\n inHeight,\n inWidth,\n outChannels,\n outDepth,\n outHeight,\n outWidth,\n strideDepth,\n strideHeight,\n strideWidth\n } = convInfo;\n const frontPad = filterDepth - 1 - convInfo.padInfo.front;\n const topPad = filterHeight - 1 - convInfo.padInfo.top;\n const leftPad = filterWidth - 1 - convInfo.padInfo.left;\n\n for (let b = 0; b < batchSize; ++b) {\n for (let d1 = 0; d1 < inChannels; ++d1) {\n // Frames of depth\n for (let xF = 0; xF < inDepth; ++xF) {\n const xFCorner = xF - frontPad;\n const xFMin = Math.max(0, Math.ceil(xFCorner / strideDepth));\n const yFMax =\n Math.min(outDepth, (filterDepth + xFCorner) / strideDepth);\n\n // Rows as per standard 2d matrix notation\n for (let xR = 0; xR < inHeight; ++xR) {\n const xRCorner = xR - topPad;\n const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));\n const yRMax =\n Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);\n // Columns as per standard 2d matrix notation\n for (let xC = 0; xC < inWidth; ++xC) {\n const xCCorner = xC - leftPad;\n const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));\n const yCMax =\n Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);\n\n let dotProd = 0;\n for (let yF = xFMin; yF < yFMax; ++yF) {\n const wF = yF * strideDepth - xFCorner;\n\n for (let yR = xRMin; yR < yRMax; ++yR) {\n const wR = yR * strideHeight - xRCorner;\n\n for (let yC = xCMin; yC < yCMax; ++yC) {\n const wC = yC * strideWidth - xCCorner;\n const dyOffset =\n dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC;\n const fltOffset = fltS0 * (filterDepth - 1 - wF) +\n fltS1 * (filterHeight - 1 - wR) +\n fltS2 * (filterWidth - 1 - wC) + fltS3 * d1;\n\n for (let d2 = 0; d2 < outChannels; ++d2) {\n const pixel = dyValues[dyOffset + d2];\n const weight = fltValues[fltOffset + d2];\n dotProd += pixel * weight;\n }\n }\n }\n }\n dxValues[dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1] =\n dotProd;\n }\n }\n }\n }\n }\n return dx.toTensor();\n }\n\n conv2dDerFilter(x: Tensor4D, dy: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n assertNotComplex([x, dy], 'conv2dDerFilter');\n\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const isChannelsLast = convInfo.dataFormat === 'channelsLast';\n const dW = ops.buffer(convInfo.filterShape, 'float32');\n\n const leftPad = convInfo.padInfo.left;\n const topPad = convInfo.padInfo.top;\n const xBuf = this.bufferSync(x);\n const dyBuf = this.bufferSync(dy);\n for (let wR = 0; wR < filterHeight; ++wR) {\n const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));\n const yRMax = Math.min(\n convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);\n\n for (let wC = 0; wC < filterWidth; ++wC) {\n const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));\n const yCMax = Math.min(\n convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);\n\n for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {\n for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {\n // Need to convolve.\n let dotProd = 0;\n for (let b = 0; b < convInfo.batchSize; ++b) {\n for (let yR = yRMin; yR < yRMax; ++yR) {\n const xR = wR + yR * strideHeight - topPad;\n for (let yC = yCMin; yC < yCMax; ++yC) {\n const xC = wC + yC * strideWidth - leftPad;\n if (isChannelsLast) {\n dotProd +=\n xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);\n } else {\n dotProd +=\n xBuf.get(b, d1, xR, xC) * dyBuf.get(b, d2, yR, yC);\n }\n }\n }\n }\n dW.set(dotProd, wR, wC, d1, d2);\n }\n }\n }\n }\n return dW.toTensor();\n }\n\n conv3dDerFilter(x: Tensor5D, dy: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n const strideDepth = convInfo.strideDepth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const filterDepth = convInfo.filterDepth;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n\n const dw = ops.buffer(convInfo.filterShape, 'float32');\n const dwValues = dw.values;\n const [dwS0, dwS1, dwS2, dwS3] = dw.strides;\n const dyValues = this.readSync(dy.dataId) as TypedArray;\n const [dyS0, dyS1, dyS2, dyS3] = dy.strides;\n const xValues = this.readSync(x.dataId) as TypedArray;\n const [xS0, xS1, xS2, xS3] = x.strides;\n\n const frontPad = convInfo.padInfo.front;\n const leftPad = convInfo.padInfo.left;\n const topPad = convInfo.padInfo.top;\n\n for (let wF = 0; wF < filterDepth; ++wF) {\n const yFMin = Math.max(0, Math.ceil((frontPad - wF) / strideDepth));\n const yFMax = Math.min(\n convInfo.outDepth, (convInfo.inDepth + frontPad - wF) / strideDepth);\n const wOffset1 = wF * dwS0;\n\n for (let wR = 0; wR < filterHeight; ++wR) {\n const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));\n const yRMax = Math.min(\n convInfo.outHeight,\n (convInfo.inHeight + topPad - wR) / strideHeight);\n const wOffset2 = wR * dwS1 + wOffset1;\n\n for (let wC = 0; wC < filterWidth; ++wC) {\n const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));\n const yCMax = Math.min(\n convInfo.outWidth,\n (convInfo.inWidth + leftPad - wC) / strideWidth);\n const wOffset3 = wC * dwS2 + wOffset2;\n\n for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {\n const wOffset4 = d1 * dwS3 + wOffset3;\n\n for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {\n let dotProd = 0;\n for (let b = 0; b < convInfo.batchSize; ++b) {\n const xOffset1 = b * xS0;\n const yOffset1 = b * dyS0;\n\n for (let yF = yFMin; yF < yFMax; ++yF) {\n const xF = wF + yF * strideDepth - frontPad;\n const xOffset2 = xF * xS1 + xOffset1;\n const yOffset2 = yF * dyS1 + yOffset1;\n\n for (let yR = yRMin; yR < yRMax; ++yR) {\n const xR = wR + yR * strideHeight - topPad;\n const xOffset3 = xR * xS2 + xOffset2;\n const yOffset3 = yR * dyS2 + yOffset2;\n\n for (let yC = yCMin; yC < yCMax; ++yC) {\n const xC = wC + yC * strideWidth - leftPad;\n const xOffset4 = xC * xS3 + xOffset3;\n const yOffset4 = yC * dyS3 + yOffset3;\n\n dotProd +=\n xValues[xOffset4 + d1] * dyValues[yOffset4 + d2];\n }\n }\n }\n }\n dwValues[wOffset4 + d2] = dotProd;\n }\n }\n }\n }\n }\n return dw.toTensor();\n }\n\n fusedDepthwiseConv2D(\n {input, filter, convInfo, bias, activation, preluActivationWeights}:\n FusedConv2DConfig): Tensor4D {\n let result = this.depthwiseConv2D(input, filter, convInfo);\n\n if (bias) {\n result = this.add(result, bias) as Tensor4D;\n }\n if (activation) {\n result =\n mapActivation(this, result, activation, preluActivationWeights) as\n Tensor4D;\n }\n return result;\n }\n\n depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n assertNotComplex([x, filter], 'depthwiseConv2D');\n\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const padLeft = convInfo.padInfo.left;\n const padTop = convInfo.padInfo.top;\n const chMul = convInfo.outChannels / convInfo.inChannels;\n const y = ops.buffer(convInfo.outShape, x.dtype as 'float32');\n const xVals = this.readSync(x.dataId) as TypedArray;\n const wVals = this.readSync(filter.dataId) as TypedArray;\n const yVals = y.values;\n\n for (let b = 0; b < convInfo.batchSize; ++b) {\n const xOffset1 = b * x.strides[0];\n const yOffset1 = b * y.strides[0];\n for (let yR = 0; yR < convInfo.outHeight; ++yR) {\n const yOffset2 = yOffset1 + yR * y.strides[1];\n const xRCorner = yR * convInfo.strideHeight - padLeft;\n for (let wR = 0; wR < filterHeight; ++wR) {\n const xR = xRCorner + wR * dilationHeight;\n if (xR < 0 || xR >= convInfo.inHeight) {\n continue;\n }\n const wOffset1 = wR * filter.strides[0];\n const xOffset2 = xOffset1 + xR * x.strides[1];\n for (let yC = 0; yC < convInfo.outWidth; ++yC) {\n const yOffset3 = yOffset2 + yC * y.strides[2];\n const xCCorner = yC * convInfo.strideWidth - padTop;\n for (let wC = 0; wC < filterWidth; ++wC) {\n const xC = xCCorner + wC * dilationWidth;\n if (xC < 0 || xC >= convInfo.inWidth) {\n continue;\n }\n const wOffset2 = wOffset1 + wC * filter.strides[1];\n const xOffset3 = xOffset2 + xC * convInfo.inChannels;\n let yOffset4 = yOffset3;\n let wOffset3 = wOffset2;\n for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {\n const xVal = xVals[xOffset3 + d1];\n for (let q = 0; q < chMul; ++q) {\n yVals[yOffset4 + q] += xVal * wVals[wOffset3 + q];\n }\n yOffset4 += chMul;\n wOffset3 += chMul;\n }\n }\n }\n }\n }\n }\n\n return y.toTensor() as Tensor4D;\n }\n\n depthwiseConv2DDerInput(dy: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n assertNotComplex([dy, filter], 'depthwiseConv2DDerInput');\n\n const dx = ops.buffer(convInfo.inShape, 'float32');\n const dxValues = dx.values;\n const [dxS0, dxS1, dxS2] = dx.strides;\n const dyValues = this.readSync(dy.dataId) as TypedArray;\n const [dyS0, dyS1, dyS2] = dy.strides;\n const fltValues = this.readSync(filter.dataId) as TypedArray;\n const [fltS0, fltS1, fltS2] = filter.strides;\n const {\n batchSize,\n filterHeight,\n filterWidth,\n inChannels,\n inHeight,\n inWidth,\n outChannels,\n outHeight,\n outWidth,\n strideHeight,\n strideWidth\n } = convInfo;\n const topPad = filterHeight - 1 - convInfo.padInfo.top;\n const leftPad = filterWidth - 1 - convInfo.padInfo.left;\n const chMul = outChannels / inChannels;\n\n for (let b = 0; b < batchSize; ++b) {\n for (let d1 = 0; d1 < inChannels; ++d1) {\n for (let xR = 0; xR < inHeight; ++xR) {\n const xRCorner = xR - topPad;\n const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));\n const yRMax =\n Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);\n\n for (let xC = 0; xC < inWidth; ++xC) {\n const xCCorner = xC - leftPad;\n const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));\n const yCMax =\n Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);\n\n let dotProd = 0;\n for (let yR = xRMin; yR < yRMax; ++yR) {\n const wR = yR * strideHeight - xRCorner;\n\n for (let yC = xCMin; yC < yCMax; ++yC) {\n const wC = yC * strideWidth - xCCorner;\n const dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC;\n const fltOffset = fltS0 * (filterHeight - 1 - wR) +\n fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;\n\n for (let dm = 0; dm < chMul; ++dm) {\n const d2 = d1 * chMul + dm;\n const pixel = dyValues[dyOffset + d2];\n const weight = fltValues[fltOffset + dm];\n dotProd += pixel * weight;\n }\n }\n }\n dxValues[dxS0 * b + dxS1 * xR + dxS2 * xC + d1] = dotProd;\n }\n }\n }\n }\n return dx.toTensor();\n }\n\n depthwiseConv2DDerFilter(x: Tensor4D, dy: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n assertNotComplex([x, dy], 'depthwiseConv2DDerFilter');\n\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const dW = ops.buffer(convInfo.filterShape, 'float32');\n\n const leftPad = convInfo.padInfo.left;\n const topPad = convInfo.padInfo.top;\n const chMul = convInfo.outChannels / convInfo.inChannels;\n\n const xBuf = this.bufferSync(x);\n const dyBuf = this.bufferSync(dy);\n for (let wR = 0; wR < filterHeight; ++wR) {\n const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));\n const yRMax = Math.min(\n convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);\n\n for (let wC = 0; wC < filterWidth; ++wC) {\n const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));\n const yCMax = Math.min(\n convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);\n\n for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {\n const d1 = Math.trunc(d2 / chMul);\n const dm = d2 % chMul;\n\n let dotProd = 0;\n for (let b = 0; b < convInfo.batchSize; ++b) {\n for (let yR = yRMin; yR < yRMax; ++yR) {\n const xR = wR + yR * strideHeight - topPad;\n for (let yC = yCMin; yC < yCMax; ++yC) {\n const xC = wC + yC * strideWidth - leftPad;\n dotProd += xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);\n }\n }\n }\n dW.set(dotProd, wR, wC, d1, dm);\n }\n }\n }\n return dW.toTensor();\n }\n\n tile(x: T, reps: number[]): T {\n assertNotComplex(x, 'tile');\n return tile(this.bufferSync(x), reps) as T;\n }\n\n pad(\n x: T, paddings: Array<[number, number]>, constantValue: number): T {\n assertNotComplex(x, 'pad');\n\n const outShape = paddings.map(\n (p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */);\n const start = paddings.map(p => p[0]);\n const xBuffer = this.bufferSync(x);\n const buffer = ops.buffer(outShape, x.dtype as 'float32');\n if (constantValue !== 0) {\n buffer.values.fill(constantValue);\n }\n\n for (let i = 0; i < x.size; i++) {\n const coords = xBuffer.indexToLoc(i);\n const outCoords = coords.map((c, i) => c + start[i]);\n buffer.set(xBuffer.get(...coords), ...outCoords);\n }\n return buffer.toTensor() as T;\n }\n\n gather(x: T, indices: Tensor1D, axis: number): T {\n assertNotComplex([x, indices], 'gather');\n\n const newShape: number[] = x.shape.slice();\n const indicesValues = this.readSync(indices.dataId) as TypedArray;\n newShape[axis] = indicesValues.length;\n const result = buffer(newShape, x.dtype);\n const xBuf = this.bufferSync(x);\n\n for (let i = 0; i < result.size; ++i) {\n const newLoc = result.indexToLoc(i);\n\n const originalLoc: number[] = newLoc.slice();\n originalLoc[axis] = indicesValues[newLoc[axis]];\n\n const originalIndex = xBuf.locToIndex(originalLoc);\n result.values[i] = xBuf.values[originalIndex];\n }\n return result.toTensor() as T;\n }\n\n batchToSpaceND(\n x: T, blockShape: number[], crops: number[][]): T {\n assertNotComplex([x], 'batchToSpaceND');\n\n const prod = blockShape.reduce((a, b) => a * b);\n\n const reshaped = array_ops_util.getReshaped(x.shape, blockShape, prod);\n const permuted =\n array_ops_util.getPermuted(reshaped.length, blockShape.length);\n const reshapedPermuted =\n array_ops_util.getReshapedPermuted(x.shape, blockShape, prod);\n const sliceBeginCoords =\n array_ops_util.getSliceBeginCoords(crops, blockShape.length);\n const sliceSize =\n array_ops_util.getSliceSize(reshapedPermuted, crops, blockShape.length);\n\n return transpose(x.reshape(reshaped), permuted)\n .reshape(reshapedPermuted)\n .slice(sliceBeginCoords, sliceSize) as T;\n }\n\n spaceToBatchND(\n x: T, blockShape: number[], paddings: Array<[number, number]>): T {\n assertNotComplex([x], 'spaceToBatchND');\n\n const prod = blockShape.reduce((a, b) => a * b);\n\n const completePaddings: Array<[number, number]> = [[0, 0]];\n completePaddings.push(...paddings);\n for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {\n completePaddings.push([0, 0]);\n }\n\n const paddedX = x.pad(completePaddings);\n\n const reshapedPaddedShape =\n array_ops_util.getReshaped(paddedX.shape, blockShape, prod, false);\n const permutedReshapedPaddedPermutation = array_ops_util.getPermuted(\n reshapedPaddedShape.length, blockShape.length, false);\n const flattenShape = array_ops_util.getReshapedPermuted(\n paddedX.shape, blockShape, prod, false);\n\n return transpose(\n paddedX.reshape(reshapedPaddedShape),\n permutedReshapedPaddedPermutation)\n .reshape(flattenShape) as T;\n }\n\n maxPool(x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n assertNotComplex(x, 'maxPool');\n const xValues = this.readSync(x.dataId) as TypedArray;\n return pool(xValues, x.shape, x.dtype, x.strides, convInfo, 'max')\n .toTensor() as Tensor4D;\n }\n\n maxPoolBackprop(dy: Tensor4D, x: Tensor4D, y: Tensor4D, convInfo: Conv2DInfo):\n Tensor4D {\n assertNotComplex([x, y], 'maxPoolBackprop');\n\n const xValues = this.readSync(x.dataId) as TypedArray;\n const maxPosBuf = buffer(\n convInfo.outShape, x.dtype,\n maxPoolPositions(xValues, x.shape, x.dtype, convInfo).values);\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;\n const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;\n const dx = ops.buffer(x.shape, 'float32');\n\n const dyBuf = this.bufferSync(dy);\n\n for (let b = 0; b < convInfo.batchSize; ++b) {\n for (let d = 0; d < convInfo.inChannels; ++d) {\n for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {\n for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {\n // Shader code begins.\n const dyRCorner = dxR - padTop;\n const dyCCorner = dxC - padLeft;\n let dotProd = 0;\n for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {\n const dyR = (dyRCorner + wR) / strideHeight;\n if (dyR < 0 || dyR >= convInfo.outHeight ||\n Math.floor(dyR) !== dyR) {\n continue;\n }\n for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {\n const dyC = (dyCCorner + wC) / strideWidth;\n if (dyC < 0 || dyC >= convInfo.outWidth ||\n Math.floor(dyC) !== dyC) {\n continue;\n }\n const maxPos = effectiveFilterHeight * effectiveFilterWidth -\n 1 - (maxPosBuf.get(b, dyR, dyC, d) as number);\n const curPos = wR * effectiveFilterWidth + wC;\n\n const mask = maxPos === curPos ? 1 : 0;\n if (mask === 0) {\n continue;\n }\n\n const pixel = dyBuf.get(b, dyR, dyC, d);\n dotProd += pixel * mask;\n }\n }\n dx.set(dotProd, b, dxR, dxC, d);\n }\n }\n }\n }\n return dx.toTensor();\n }\n\n avgPoolBackprop(dy: Tensor4D, x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n assertNotComplex([dy, x], 'avgPoolBackprop');\n\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;\n const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;\n const dx = ops.buffer(x.shape, 'float32');\n\n const avgMultiplier = 1 / (filterHeight * filterWidth);\n\n const dyBuf = this.bufferSync(dy);\n\n for (let b = 0; b < convInfo.batchSize; ++b) {\n for (let d = 0; d < convInfo.inChannels; ++d) {\n for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {\n for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {\n // Shader code begins.\n const dyRCorner = dxR - padTop;\n const dyCCorner = dxC - padLeft;\n let dotProd = 0;\n for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {\n const dyR = (dyRCorner + wR) / strideHeight;\n if (dyR < 0 || dyR >= convInfo.outHeight ||\n Math.floor(dyR) !== dyR) {\n continue;\n }\n for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {\n const dyC = (dyCCorner + wC) / strideWidth;\n if (dyC < 0 || dyC >= convInfo.outWidth ||\n Math.floor(dyC) !== dyC) {\n continue;\n }\n\n const pixel = dyBuf.get(b, dyR, dyC, d);\n dotProd += pixel;\n }\n }\n dx.set(dotProd * avgMultiplier, b, dxR, dxC, d);\n }\n }\n }\n }\n return dx.toTensor();\n }\n\n private pool3d(x: Tensor5D, convInfo: Conv3DInfo, poolType: 'max'|'avg'):\n Tensor5D {\n assertNotComplex(x, 'pool3d');\n\n const strideDepth = convInfo.strideDepth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationDepth = convInfo.dilationDepth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterDepth = convInfo.effectiveFilterDepth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n const padFront = convInfo.padInfo.front;\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n\n const initialValue =\n (poolType === 'max' ? Number.NEGATIVE_INFINITY :\n Number.POSITIVE_INFINITY);\n\n const xValues = this.readSync(x.dataId) as TypedArray;\n const output = ops.buffer(convInfo.outShape, x.dtype);\n const outputVals = output.values;\n\n const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] *\n convInfo.outShape[3] * convInfo.outShape[4];\n const outputDepthStrides =\n convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4];\n const outputRowStrides = convInfo.outShape[3] * convInfo.outShape[4];\n const outputColStrides = convInfo.outShape[4];\n\n for (let batch = 0; batch < convInfo.batchSize; ++batch) {\n const outputBatchOffset = batch * outputBatchStrides;\n const inputBatchOffset = batch * x.strides[0];\n for (let channel = 0; channel < convInfo.inChannels; ++channel) {\n for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {\n const xDepthCorner = yDepth * strideDepth - padFront;\n let xDepthMin = xDepthCorner;\n while (xDepthMin < 0) {\n xDepthMin += dilationDepth;\n }\n const xDepthMax =\n Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);\n const outputDepthOffset =\n outputBatchOffset + yDepth * outputDepthStrides;\n for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {\n const xRowCorner = yRow * strideHeight - padTop;\n let xRowMin = xRowCorner;\n while (xRowMin < 0) {\n xRowMin += dilationHeight;\n }\n const xRowMax =\n Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);\n const outputRowOffset = outputDepthOffset + yRow * outputRowStrides;\n for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {\n const xColCorner = yCol * strideWidth - padLeft;\n let xColMin = xColCorner;\n while (xColMin < 0) {\n xColMin += dilationWidth;\n }\n const xColMax =\n Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);\n // Shader code begins\n const outputColOffset = outputRowOffset + yCol * outputColStrides;\n let minMaxValue = initialValue;\n let avgValue = 0;\n let count = 0;\n for (let xDepth = xDepthMin; xDepth < xDepthMax;\n xDepth += dilationDepth) {\n const xDepthOffset = inputBatchOffset + xDepth * x.strides[1];\n for (let xRow = xRowMin; xRow < xRowMax;\n xRow += dilationHeight) {\n const xRowOffset = xDepthOffset + xRow * x.strides[2];\n for (let xCol = xColMin; xCol < xColMax;\n xCol += dilationWidth) {\n const xColOffset = xRowOffset + xCol * x.strides[3];\n const pixel = xValues[xColOffset + channel];\n if ((poolType === 'max' && pixel > minMaxValue)) {\n minMaxValue = pixel;\n } else if (poolType === 'avg') {\n avgValue += pixel;\n count++;\n }\n if (isNaN(minMaxValue)) {\n break;\n }\n }\n if (isNaN(minMaxValue)) {\n break;\n }\n }\n if (isNaN(minMaxValue)) {\n break;\n }\n }\n const outputOffset = outputColOffset + channel;\n outputVals[outputOffset] =\n poolType === 'avg' ? avgValue / count : minMaxValue;\n }\n }\n }\n }\n }\n return output.toTensor() as Tensor5D;\n }\n\n avgPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n assertNotComplex(x, 'avgPool3d');\n\n return this.pool3d(x, convInfo, 'avg').toFloat();\n }\n\n avgPool3dBackprop(dy: Tensor5D, x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n assertNotComplex([dy, x], 'avgPool3dBackprop');\n\n const strideDepth = convInfo.strideDepth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const filterDepth = convInfo.filterDepth;\n const filterHeight = convInfo.filterHeight;\n const filterWidth = convInfo.filterWidth;\n const dilationDepth = convInfo.dilationDepth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterDepth = convInfo.effectiveFilterDepth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;\n const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;\n const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;\n const dx = ops.buffer(x.shape, 'float32');\n\n const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);\n\n const dyBuf = this.bufferSync(dy);\n\n for (let batch = 0; batch < convInfo.batchSize; ++batch) {\n for (let channel = 0; channel < convInfo.inChannels; ++channel) {\n for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {\n for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {\n for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {\n // Shader code begins.\n const dyDepthCorner = dxDepth - padFront;\n const dyRowCorner = dxRow - padTop;\n const dyColCorner = dxCol - padLeft;\n let dotProd = 0;\n for (let wDepth = 0; wDepth < effectiveFilterDepth;\n wDepth += dilationDepth) {\n const dyDepth = (dyDepthCorner + wDepth) / strideDepth;\n if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||\n Math.floor(dyDepth) !== dyDepth) {\n continue;\n }\n for (let wRow = 0; wRow < effectiveFilterHeight;\n wRow += dilationHeight) {\n const dyRow = (dyRowCorner + wRow) / strideHeight;\n if (dyRow < 0 || dyRow >= convInfo.outHeight ||\n Math.floor(dyRow) !== dyRow) {\n continue;\n }\n for (let wCol = 0; wCol < effectiveFilterWidth;\n wCol += dilationWidth) {\n const dyCol = (dyColCorner + wCol) / strideWidth;\n if (dyCol < 0 || dyCol >= convInfo.outWidth ||\n Math.floor(dyCol) !== dyCol) {\n continue;\n }\n\n const pixel =\n dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);\n dotProd += pixel;\n }\n }\n }\n dx.set(\n dotProd * avgMultiplier, batch, dxDepth, dxRow, dxCol,\n channel);\n }\n }\n }\n }\n }\n return dx.toTensor();\n }\n\n maxPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n assertNotComplex(x, 'maxPool3d');\n\n return this.pool3d(x, convInfo, 'max').toFloat();\n }\n\n private maxPool3dPositions(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n const maxPositions = ops.buffer(convInfo.outShape, 'int32');\n const strideDepth = convInfo.strideDepth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationDepth = convInfo.dilationDepth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterDepth = convInfo.effectiveFilterDepth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n const padFront = convInfo.padInfo.front;\n const padTop = convInfo.padInfo.top;\n const padLeft = convInfo.padInfo.left;\n\n const xBuf = this.bufferSync(x);\n for (let batch = 0; batch < convInfo.batchSize; ++batch) {\n for (let channel = 0; channel < convInfo.inChannels; ++channel) {\n for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {\n const xDepthCorner = yDepth * strideDepth - padFront;\n let xDepthMin = xDepthCorner;\n while (xDepthMin < 0) {\n xDepthMin += dilationDepth;\n }\n const xDepthMax =\n Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);\n for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {\n const xRowCorner = yRow * strideHeight - padTop;\n let xRowMin = xRowCorner;\n while (xRowMin < 0) {\n xRowMin += dilationHeight;\n }\n const xRowMax =\n Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);\n for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {\n const xColCorner = yCol * strideWidth - padLeft;\n let xColMin = xColCorner;\n while (xColMin < 0) {\n xColMin += dilationWidth;\n }\n const xColMax =\n Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);\n\n // Shader code begins\n let maxValue = Number.NEGATIVE_INFINITY;\n let maxPosition = -1;\n\n for (let xDepth = xDepthMin; xDepth < xDepthMax;\n xDepth += dilationDepth) {\n const wDepth = xDepth - xDepthCorner;\n for (let xRow = xRowMin; xRow < xRowMax;\n xRow += dilationHeight) {\n const wRow = xRow - xRowCorner;\n for (let xCol = xColMin; xCol < xColMax;\n xCol += dilationWidth) {\n const wCol = xCol - xColCorner;\n const pixel = xBuf.get(batch, xDepth, xRow, xCol, channel);\n if (pixel >= maxValue) {\n maxValue = pixel;\n maxPosition = wDepth * effectiveFilterHeight *\n effectiveFilterWidth +\n wRow * effectiveFilterHeight + wCol;\n }\n }\n }\n }\n\n maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel);\n }\n }\n }\n }\n }\n return maxPositions.toTensor() as Tensor5D;\n }\n\n maxPool3dBackprop(\n dy: Tensor5D, x: Tensor5D, y: Tensor5D, convInfo: Conv3DInfo): Tensor5D {\n assertNotComplex([x, y], 'maxPool3dBackprop');\n\n const maxPositions = this.maxPool3dPositions(x, convInfo);\n const strideDepth = convInfo.strideDepth;\n const strideHeight = convInfo.strideHeight;\n const strideWidth = convInfo.strideWidth;\n const dilationDepth = convInfo.dilationDepth;\n const dilationHeight = convInfo.dilationHeight;\n const dilationWidth = convInfo.dilationWidth;\n const effectiveFilterDepth = convInfo.effectiveFilterDepth;\n const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;\n const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;\n const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;\n const dx = ops.buffer(x.shape, 'float32');\n\n const maxPosBuf = this.bufferSync(maxPositions);\n const dyBuf = this.bufferSync(dy);\n\n for (let batch = 0; batch < convInfo.batchSize; ++batch) {\n for (let channel = 0; channel < convInfo.inChannels; ++channel) {\n for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {\n for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {\n for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {\n // Shader code begins\n const dyDepthCorner = dxDepth - padFront;\n const dyRowCorner = dxRow - padTop;\n const dyColCorner = dxCol - padLeft;\n let dotProd = 0;\n for (let wDepth = 0; wDepth < effectiveFilterDepth;\n wDepth += dilationDepth) {\n const dyDepth = (dyDepthCorner + wDepth) / strideDepth;\n if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||\n Math.floor(dyDepth) !== dyDepth) {\n continue;\n }\n for (let wRow = 0; wRow < effectiveFilterHeight;\n wRow += dilationHeight) {\n const dyRow = (dyRowCorner + wRow) / strideHeight;\n if (dyRow < 0 || dyRow >= convInfo.outHeight ||\n Math.floor(dyRow) !== dyRow) {\n continue;\n }\n for (let wCol = 0; wCol < effectiveFilterWidth;\n wCol += dilationWidth) {\n const dyCol = (dyColCorner + wCol) / strideWidth;\n if (dyCol < 0 || dyCol >= convInfo.outWidth ||\n Math.floor(dyCol) !== dyCol) {\n continue;\n }\n\n const maxPos = effectiveFilterDepth *\n effectiveFilterHeight * effectiveFilterWidth -\n 1 -\n maxPosBuf.get(batch, dyDepth, dyRow, dyCol, channel);\n const curPos =\n wDepth * effectiveFilterHeight * effectiveFilterWidth +\n wRow * effectiveFilterWidth + wCol;\n\n const mask = maxPos === curPos ? 1 : 0;\n if (mask === 0) {\n continue;\n }\n\n const pixel =\n dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);\n dotProd += pixel * mask;\n }\n }\n }\n dx.set(dotProd, batch, dxDepth, dxRow, dxCol, channel);\n }\n }\n }\n }\n }\n return dx.toTensor();\n }\n\n cast(x: T, dtype: DataType): T {\n return backend_util.castTensor(x, dtype, this);\n }\n\n reshape(x: Tensor, shape: ShapeMap[R]): Tensor {\n return backend_util.reshapeTensor(x, shape);\n }\n\n avgPool(x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {\n assertNotComplex(x, 'avgPool');\n assertNotComplex(x, 'maxPool');\n const xValues = this.readSync(x.dataId) as TypedArray;\n return pool(xValues, x.shape, x.dtype, x.strides, convInfo, 'avg')\n .toTensor()\n .toFloat() as Tensor4D;\n }\n\n resizeBilinear(\n x: Tensor4D, newHeight: number, newWidth: number,\n alignCorners: boolean): Tensor4D {\n assertNotComplex(x, 'resizeBilinear');\n\n const [batch, oldHeight, oldWidth, numChannels] = x.shape;\n const xValues = this.readSync(x.dataId) as TypedArray;\n const result = new Float32Array(\n util.sizeFromShape([batch, newHeight, newWidth, numChannels]));\n\n const effectiveInputSize: [number, number] = [\n (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,\n (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth\n ];\n\n const effectiveOutputSize: [number, number] = [\n (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,\n (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth\n ];\n let outputIdx = 0;\n const effectiveRowSizeRatio =\n effectiveInputSize[0] / effectiveOutputSize[0];\n const effectiveColSizeRatio =\n effectiveInputSize[1] / effectiveOutputSize[1];\n for (let b = 0; b < batch; b++) {\n for (let r = 0; r < newHeight; r++) {\n const sourceFracRow = effectiveRowSizeRatio * r;\n const sourceRowFloor = Math.floor(sourceFracRow);\n const rowFrac = sourceFracRow - sourceRowFloor;\n const sourceRowCeil = Math.min(oldHeight - 1, Math.ceil(sourceFracRow));\n const topRowOffset = b * x.strides[0] + sourceRowFloor * x.strides[1];\n const botRowOffset = b * x.strides[0] + sourceRowCeil * x.strides[1];\n for (let c = 0; c < newWidth; c++) {\n const sourceFracCol = effectiveColSizeRatio * c;\n const sourceColFloor = Math.floor(sourceFracCol);\n const colFrac = sourceFracCol - sourceColFloor;\n const sourceColCeil =\n Math.min(oldWidth - 1, Math.ceil(sourceFracCol));\n const topLeftOffest = topRowOffset + sourceColFloor * x.strides[2];\n const botLeftOffset = botRowOffset + sourceColFloor * x.strides[2];\n const topRightOffset = topRowOffset + sourceColCeil * x.strides[2];\n const botRightOffest = botRowOffset + sourceColCeil * x.strides[2];\n for (let d = 0; d < numChannels; d++) {\n // Begin shader.\n\n // Compute the fractional index of the source.\n const topLeft = xValues[topLeftOffest + d];\n const bottomLeft = xValues[botLeftOffset + d];\n const topRight = xValues[topRightOffset + d];\n const bottomRight = xValues[botRightOffest + d];\n\n const top = topLeft + (topRight - topLeft) * colFrac;\n const bottom = bottomLeft + (bottomRight - bottomLeft) * colFrac;\n const newValue = top + (bottom - top) * rowFrac;\n\n result[outputIdx++] = newValue;\n }\n }\n }\n }\n return ops.tensor(result, [batch, newHeight, newWidth, numChannels]);\n }\n\n resizeBilinearBackprop(dy: Tensor4D, x: Tensor4D, alignCorners: boolean) {\n assertNotComplex([dy, x], 'resizeBilinearBackprop');\n\n const [batch, xHeight, xWidth, depth] = x.shape;\n const [, yHeight, yWidth] = dy.shape;\n\n const output = new Float32Array(batch * xHeight * xWidth * depth);\n\n // In the backwards pass, we want to find the pixels that were generated\n // for each pixel in the input image the forward pass and add the\n // corresponding coefficient from dy to the gradient (with some\n // interpolation).\n\n const effectiveXSize: [number, number] = [\n (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,\n (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth\n ];\n\n const effectiveYSize: [number, number] = [\n (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,\n (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth\n ];\n\n const heightScale = effectiveXSize[0] / effectiveYSize[0];\n const widthScale = effectiveXSize[1] / effectiveYSize[1];\n\n // Reference implementation\n // tslint:disable-next-line:max-line-length\n // https://github.com/tensorflow/tensorflow/blob/3039375c86a5bbc9610c7725dcaa95d635f87ba2/tensorflow/core/kernels/resize_bilinear_op.cc#L275\n\n const dyValues = this.readSync(dy.dataId) as TypedArray;\n let offset = 0;\n for (let b = 0; b < batch; b++) {\n const bOffset = b * x.strides[0];\n for (let r = 0; r < yHeight; r++) {\n const dxR = r * heightScale;\n const topDxRIndex = Math.floor(dxR);\n const bottomDxRIndex = Math.min(Math.ceil(dxR), xHeight - 1);\n\n const topDxROffset = bOffset + topDxRIndex * x.strides[1];\n const bottomDxROffset = bOffset + bottomDxRIndex * x.strides[1];\n\n const dxRLerp = dxR - topDxRIndex;\n const inverseDxRLerp = 1.0 - dxRLerp;\n for (let c = 0; c < yWidth; c++) {\n const dxC = c * widthScale;\n const leftDxCIndex = Math.floor(dxC);\n const rightDxCIndex = Math.min(Math.ceil(dxC), xWidth - 1);\n const dxCLerp = dxC - leftDxCIndex;\n const inverseDxCLerp = 1.0 - dxCLerp;\n\n const topLeftRCOffset = topDxROffset + leftDxCIndex * x.strides[2];\n const topRightRCOffset = topDxROffset + rightDxCIndex * x.strides[2];\n const bottomLeftRCOffset =\n bottomDxROffset + leftDxCIndex * x.strides[2];\n const bottomRightRCOffset =\n bottomDxROffset + rightDxCIndex * x.strides[2];\n\n const inverseDxRLerpTimesInverseDxCLerp =\n inverseDxRLerp * inverseDxCLerp;\n const inverseDxRLerpTimesDxCLerp = inverseDxRLerp * dxCLerp;\n const dxRLerpTimesInverseDxCLerp = dxRLerp * inverseDxCLerp;\n const dxRLerpTimesDxCLerp = dxRLerp * dxCLerp;\n for (let d = 0; d < depth; d++) {\n const dyVal = dyValues[offset++];\n output[topLeftRCOffset + d] +=\n dyVal * inverseDxRLerpTimesInverseDxCLerp;\n output[topRightRCOffset + d] += dyVal * inverseDxRLerpTimesDxCLerp;\n output[bottomLeftRCOffset + d] +=\n dyVal * dxRLerpTimesInverseDxCLerp;\n output[bottomRightRCOffset + d] += dyVal * dxRLerpTimesDxCLerp;\n }\n }\n }\n }\n return ops.tensor4d(output, [batch, xWidth, xHeight, depth], x.dtype);\n }\n\n resizeNearestNeighbor(\n x: Tensor4D, newHeight: number, newWidth: number,\n alignCorners: boolean): Tensor4D {\n assertNotComplex(x, 'resizeNearestNeighbor');\n\n const [batch, oldHeight, oldWidth, numChannels] = x.shape;\n const xValues = this.readSync(x.dataId) as TypedArray;\n const output = new Float32Array(batch * newHeight * newWidth * numChannels);\n\n const effectiveInputSize: [number, number] = [\n (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,\n (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth\n ];\n\n const effectiveOutputSize: [number, number] = [\n (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,\n (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth\n ];\n\n const effectiveRowSizeRatio =\n effectiveInputSize[0] / effectiveOutputSize[0];\n const effectiveColSizeRatio =\n effectiveInputSize[1] / effectiveOutputSize[1];\n\n let outputOffset = 0;\n for (let b = 0; b < batch; b++) {\n const batchOffset = b * x.strides[0];\n for (let r = 0; r < newHeight; r++) {\n const sourceFracRow = effectiveRowSizeRatio * r;\n const sourceNearestRow = Math.min(\n oldHeight - 1,\n alignCorners ? Math.round(sourceFracRow) :\n Math.floor(sourceFracRow));\n const rowOffset = batchOffset + sourceNearestRow * x.strides[1];\n for (let c = 0; c < newWidth; c++) {\n const sourceFracCol = effectiveColSizeRatio * c;\n const sourceNearestCol = Math.min(\n oldWidth - 1,\n alignCorners ? Math.round(sourceFracCol) :\n Math.floor(sourceFracCol));\n const colOffset = rowOffset + sourceNearestCol * x.strides[2];\n for (let d = 0; d < numChannels; d++) {\n // Begin shader.\n // Compute the fractional index of the source.\n const newVal = xValues[colOffset + d];\n output[outputOffset++] = newVal;\n }\n }\n }\n }\n return ops.tensor(\n output, [batch, newHeight, newWidth, numChannels], x.dtype);\n }\n\n resizeNearestNeighborBackprop(\n dy: Tensor4D, x: Tensor4D, alignCorners: boolean) {\n assertNotComplex([dy, x], 'resizeNearestNeighborBackprop');\n\n const [batch, xHeight, xWidth, depth] = x.shape;\n const [, yHeight, yWidth] = dy.shape;\n\n const output = new Float32Array(batch * xHeight * xWidth * depth);\n const dyValues = this.readSync(dy.dataId) as TypedArray;\n\n // In the backwards pass, we want to find the pixels that were generated\n // for each pixel in the input image the forward pass\n\n const effectiveXSize: [number, number] = [\n (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,\n (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth\n ];\n\n const effectiveYSize: [number, number] = [\n (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,\n (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth\n ];\n\n const heightScale = effectiveXSize[0] / effectiveYSize[0];\n const widthScale = effectiveXSize[1] / effectiveYSize[1];\n\n const invHeightScale = 1 / heightScale;\n const invWidthScale = 1 / widthScale;\n\n // This defines the size of the window of values around a particular\n // index in dy that we want to search for contributions to dx.\n const winHeight = (Math.ceil(invHeightScale) * 2) + 2;\n const winWidth = (Math.ceil(invWidthScale) * 2) + 2;\n\n // Loop over the output space.\n for (let b = 0; b < batch; b++) {\n const batchOffset = b * x.strides[0];\n for (let r = 0; r < xHeight; r++) {\n const rowOffset = batchOffset + r * x.strides[1];\n\n // Compute bounds for where in dy we will look\n const startRLerp = Math.floor(r * invHeightScale);\n const startDyR = Math.floor(startRLerp - (winHeight / 2));\n for (let c = 0; c < xWidth; c++) {\n const colOffset = rowOffset + c * x.strides[2];\n\n // Compute bounds for where in dy we will look\n const startCLerp = Math.floor(c * invWidthScale);\n const startDyC = Math.floor(startCLerp - (winWidth / 2));\n\n for (let d = 0; d < depth; d++) {\n let accum = 0;\n // loop over dy\n\n for (let dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {\n const dyR = dyRIndex + startDyR;\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= yHeight) {\n continue;\n }\n\n const dyROffset = batchOffset + dyR * dy.strides[1];\n const sourceFracRow = dyR * heightScale;\n const sourceNearestRow = Math.min(\n xHeight - 1,\n alignCorners ? Math.round(sourceFracRow) :\n Math.floor(sourceFracRow));\n if (r !== sourceNearestRow) {\n continue;\n }\n for (let dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {\n const dyC = dyCIndex + startDyC;\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= yWidth) {\n continue;\n }\n\n const dyCOffset = dyROffset + dyC * dy.strides[2];\n const sourceFracCol = dyC * widthScale;\n const sourceNearestCol = Math.min(\n xWidth - 1,\n alignCorners ? Math.round(sourceFracCol) :\n Math.floor(sourceFracCol));\n\n if (c === sourceNearestCol) {\n accum += dyValues[dyCOffset + d];\n }\n }\n }\n output[colOffset + d] = accum;\n }\n }\n }\n }\n return ops.tensor4d(output, x.shape, x.dtype);\n }\n\n batchNormalization(\n x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D,\n varianceEpsilon: number, scale?: Tensor4D|Tensor1D,\n offset?: Tensor4D|Tensor1D): Tensor4D {\n assertNotComplex([x, mean, variance, scale, offset], 'batchNorm');\n\n const xVals = this.readSync(x.dataId) as TypedArray;\n const mVals = this.readSync(mean.dataId) as TypedArray;\n const varVals = this.readSync(variance.dataId) as TypedArray;\n const sVals = scale ? this.readSync(scale.dataId) as TypedArray :\n new Float32Array([1]);\n const offVals = offset ? this.readSync(offset.dataId) as TypedArray :\n new Float32Array([0]);\n const outVals = new Float32Array(xVals.length);\n\n const offValsLength = offVals.length;\n const sValsLength = sVals.length;\n const varValsLength = varVals.length;\n const mValsLength = mVals.length;\n\n let offi = 0;\n let mi = 0;\n let si = 0;\n let vi = 0;\n for (let i = 0; i < xVals.length; ++i) {\n outVals[i] = offVals[offi++] +\n (xVals[i] - mVals[mi++]) * sVals[si++] /\n Math.sqrt(varVals[vi++] + varianceEpsilon);\n if (offi >= offValsLength) {\n offi = 0;\n }\n if (mi >= mValsLength) {\n mi = 0;\n }\n if (si >= sValsLength) {\n si = 0;\n }\n if (vi >= varValsLength) {\n vi = 0;\n }\n }\n return tensor4d(outVals, x.shape);\n }\n\n localResponseNormalization4D(\n x: Tensor4D, depthRadius: number, bias: number, alpha: number,\n beta: number): Tensor4D {\n assertNotComplex(x, 'localResponseNormalization4D');\n\n const channels = x.shape[3];\n const maxD = channels - 1;\n const xValues = this.readSync(x.dataId) as TypedArray;\n const size = x.size;\n const result = new Float32Array(size);\n\n function sumAcrossChannels(offset: number) {\n const currentChannel = offset % channels;\n let beginSumOffset =\n offset - currentChannel + Math.max(0, currentChannel - depthRadius);\n const endSumOffset = offset - currentChannel +\n Math.min(currentChannel + depthRadius, maxD);\n\n let sum = 0.0;\n for (; beginSumOffset <= endSumOffset; beginSumOffset++) {\n const z = xValues[beginSumOffset];\n sum += z * z;\n }\n return sum;\n }\n\n for (let offset = 0; offset < size; offset++) {\n const sum = sumAcrossChannels(offset);\n const val = xValues[offset] * Math.pow(bias + alpha * sum, -beta);\n result[offset] = val;\n }\n\n return ops.tensor4d(result, x.shape);\n }\n\n LRNGrad(\n dy: Tensor4D, inputImage: Tensor4D, outputImage: Tensor4D,\n depthRadius: number, bias: number, alpha: number,\n beta: number): Tensor4D {\n assertNotComplex(dy, 'LRNGrad');\n const channels = dy.shape[3];\n const dyValues = this.readSync(dy.dataId) as TypedArray;\n const inputImageValues = this.readSync(inputImage.dataId) as TypedArray;\n const outputImageValues = this.readSync(outputImage.dataId) as TypedArray;\n const result = new Float32Array(dy.size);\n const size = dy.size;\n\n for (let offset = 0; offset < size; offset++) {\n const currentChannel = offset % channels;\n const depthBegin =\n (offset - currentChannel) + Math.max(0, currentChannel - depthRadius);\n const depthEnd = (offset - currentChannel) +\n Math.min(channels, currentChannel + depthRadius + 1);\n\n let norm = 0;\n for (let k = depthBegin; k < depthEnd; k++) {\n norm += Math.pow(inputImageValues[k], 2);\n }\n norm = alpha * norm + bias;\n\n for (let k = depthBegin; k < depthEnd; k++) {\n let dyi = -2 * alpha * beta * inputImageValues[k] *\n outputImageValues[offset] / norm;\n if (offset === k) {\n dyi += Math.pow(norm, -beta);\n }\n dyi *= dyValues[offset];\n result[k] += dyi;\n }\n }\n return ops.tensor4d(result, dy.shape);\n }\n\n multinomial(\n logits: Tensor2D, normalized: boolean, numSamples: number,\n seed: number): Tensor2D {\n assertNotComplex(logits, 'multinomial');\n\n const probabilities = normalized ? logits : ops.softmax(logits);\n const batchSize = probabilities.shape[0];\n const numEvents = probabilities.shape[1];\n const res = ops.zeros([batchSize, numSamples], 'int32');\n const resVals = this.readSync(res.dataId) as TypedArray;\n const probVals = this.readSync(probabilities.dataId) as TypedArray;\n\n for (let b = 0; b < batchSize; ++b) {\n const offset = b * numEvents;\n // The cdf won't include the last event. It will be implicit if no other\n // event happened.\n const cdf = new Float32Array(numEvents - 1);\n cdf[0] = probVals[offset];\n for (let event = 1; event < cdf.length; ++event) {\n cdf[event] = cdf[event - 1] + probVals[offset + event];\n }\n\n const random = seedrandom.alea(seed.toString());\n const outOffset = b * numSamples;\n for (let sampleId = 0; sampleId < numSamples; ++sampleId) {\n const r = random();\n\n // Assume last event happened by default.\n resVals[outOffset + sampleId] = cdf.length;\n\n for (let event = 0; event < cdf.length; event++) {\n if (r < cdf[event]) {\n resVals[outOffset + sampleId] = event;\n break;\n }\n }\n }\n }\n return res;\n }\n\n oneHot(indices: Tensor1D, depth: number, onValue: number, offValue: number):\n Tensor2D {\n assertNotComplex(indices, 'oneHot');\n\n const res = new Float32Array(indices.size * depth);\n res.fill(offValue);\n const indicesVal = this.readSync(indices.dataId) as TypedArray;\n\n for (let event = 0; event < indices.size; ++event) {\n if (indicesVal[event] >= 0 && indicesVal[event] < depth) {\n res[event * depth + indicesVal[event]] = onValue;\n }\n }\n return ops.tensor2d(res, [indices.size, depth], 'int32');\n }\n\n nonMaxSuppression(\n boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number,\n iouThreshold: number, scoreThreshold: number): Tensor1D {\n assertNotComplex(boxes, 'nonMaxSuppression');\n\n const boxesVals = this.readSync(boxes.dataId) as TypedArray;\n const scoresVals = this.readSync(scores.dataId) as TypedArray;\n return nonMaxSuppressionV3(\n boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);\n }\n\n fft(x: Tensor2D): Tensor2D {\n return this.fftBatch(x, false);\n }\n\n ifft(x: Tensor2D): Tensor2D {\n return this.fftBatch(x, true);\n }\n\n /**\n * Calculate FFT of inner most elements of batch tensor.\n */\n private fftBatch(x: Tensor2D, inverse: boolean): Tensor2D {\n const batch = x.shape[0];\n const innerDim = x.shape[1];\n // Collects real and imaginary values separately.\n const realResult = ops.buffer(x.shape, 'float32');\n const imagResult = ops.buffer(x.shape, 'float32');\n\n const real = ops.real(x).as2D(batch, innerDim);\n const imag = ops.imag(x).as2D(batch, innerDim);\n\n for (let b = 0; b < batch; b++) {\n // TODO: Support slice ops for complex type.\n const r = real.slice([b, 0], [1, innerDim]);\n const i = imag.slice([b, 0], [1, innerDim]);\n const input = ops.complex(r, i);\n // Run FFT by batch element.\n const res =\n this.readSync(this.fftImpl(input, inverse).dataId) as Float32Array;\n for (let d = 0; d < innerDim; d++) {\n const c = complex_util.getComplexWithIndex(res, d);\n realResult.values[b * innerDim + d] = c.real;\n imagResult.values[b * innerDim + d] = c.imag;\n }\n }\n\n const t = ops.complex(realResult.toTensor(), imagResult.toTensor());\n return t.as2D(batch, innerDim);\n }\n\n private fftImpl(x: Tensor2D, inverse: boolean): Tensor2D {\n const x1D = x.as1D();\n\n const n = x1D.size;\n\n if (this.isExponentOf2(n)) {\n let result = this.fftRadix2(x1D, n, inverse).as2D(x.shape[0], x.shape[1]);\n if (inverse) {\n result = ops.complex(\n ops.real(result).div(scalar(n)),\n ops.imag(result).div(scalar(n))) as Tensor2D;\n }\n return result;\n } else {\n const data = this.readSync(x.dataId) as TypedArray;\n const rawOutput =\n this.fourierTransformByMatmul(data, n, inverse) as Float32Array;\n const output = complex_util.splitRealAndImagArrays(rawOutput);\n return ops.complex(output.real, output.imag).as2D(x.shape[0], x.shape[1]);\n }\n }\n\n private isExponentOf2(size: number): boolean {\n return (size & size - 1) === 0;\n }\n\n // FFT using Cooley-Tukey algorithm on radix 2 dimensional input.\n private fftRadix2(input: Tensor1D, size: number, inverse: boolean): Tensor1D {\n if (size === 1) {\n return input;\n }\n const data = this.readSync(input.dataId) as TypedArray as Float32Array;\n const half = size / 2;\n const evenComplex = complex_util.complexWithEvenIndex(data);\n let evenTensor = ops.complex(evenComplex.real, evenComplex.imag).as1D();\n const oddComplex = complex_util.complexWithOddIndex(data);\n let oddTensor = ops.complex(oddComplex.real, oddComplex.imag).as1D();\n\n // Recursive call for half part of original input.\n evenTensor = this.fftRadix2(evenTensor, half, inverse);\n oddTensor = this.fftRadix2(oddTensor, half, inverse);\n\n const e = complex_util.exponents(size, inverse);\n const exponent = ops.complex(e.real, e.imag).mul(oddTensor);\n\n const addPart = evenTensor.add(exponent);\n const subPart = evenTensor.sub(exponent);\n\n const realTensor = ops.real(addPart).concat(ops.real(subPart));\n const imagTensor = ops.imag(addPart).concat(ops.imag(subPart));\n\n return ops.complex(realTensor, imagTensor).as1D();\n }\n\n // Calculate fourier transform by multplying sinusoid matrix.\n private fourierTransformByMatmul(\n data: TypedArray, size: number, inverse: boolean): TypedArray {\n const ret = new Float32Array(size * 2);\n // TODO: Use matmul instead once it supports complex64 type.\n for (let r = 0; r < size; r++) {\n let real = 0.0;\n let imag = 0.0;\n for (let c = 0; c < size; c++) {\n const e = complex_util.exponent(r * c, size, inverse);\n const term = complex_util.getComplexWithIndex(data as Float32Array, c);\n real += term.real * e.real - term.imag * e.imag;\n imag += term.real * e.imag + term.imag * e.real;\n }\n if (inverse) {\n real /= size;\n imag /= size;\n }\n complex_util.assignToTypedArray(ret, real, imag, r);\n }\n return ret;\n }\n\n depthToSpace(x: Tensor4D, blockSize: number, dataFormat: 'NHWC'|'NCHW'):\n Tensor4D {\n util.assert(\n dataFormat === 'NHWC',\n () => `Only NHWC dataFormat supported on CPU for depthToSpace. Got ${\n dataFormat}`);\n util.assert(\n blockSize > 1,\n () =>\n `blockSize should be > 1 for depthToSpace, but was: ${blockSize}`);\n\n const batchSize = x.shape[0];\n const inputHeight = x.shape[1];\n const inputWidth = x.shape[2];\n const inputDepth = x.shape[3];\n\n const outputHeight = inputHeight * blockSize;\n const outputWidth = inputWidth * blockSize;\n const outputDepth = inputDepth / (blockSize * blockSize);\n\n const xValues = this.readSync(x.dataId) as TypedArray;\n const result =\n new Float32Array(batchSize * outputHeight * outputWidth * outputDepth);\n\n let outputIdx = 0;\n for (let b = 0; b < batchSize; ++b) {\n for (let h = 0; h < outputHeight; ++h) {\n const inH = Math.floor(h / blockSize);\n const offsetH = (h % blockSize);\n for (let w = 0; w < outputWidth; ++w) {\n const inW = Math.floor(w / blockSize);\n const offsetW = (w % blockSize);\n const offsetD = (offsetH * blockSize + offsetW) * outputDepth;\n for (let d = 0; d < outputDepth; ++d) {\n const inD = d + offsetD;\n const inputIdx =\n inD + inputDepth * (inW + inputWidth * (inH + inputHeight * b));\n result[outputIdx++] = xValues[inputIdx];\n }\n }\n }\n }\n return ops.tensor4d(\n result, [batchSize, outputHeight, outputWidth, outputDepth]);\n }\n\n private broadcastedBinaryOp(\n a: Tensor, b: Tensor, dtype: DataType,\n op: (a: number, b: number) => number): Tensor {\n const newShape =\n broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);\n const result = ops.buffer(newShape, dtype);\n const aVals = this.readSync(a.dataId) as TypedArray;\n const bVals = this.readSync(b.dataId) as TypedArray;\n const aBroadcastDims = broadcast_util.getBroadcastDims(a.shape, newShape);\n const bBroadcastDims = broadcast_util.getBroadcastDims(b.shape, newShape);\n\n const resVals = result.values;\n if (aBroadcastDims.length + bBroadcastDims.length === 0) {\n for (let i = 0; i < resVals.length; ++i) {\n resVals[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);\n }\n } else {\n const aBuf = this.bufferSync(a);\n const bBuf = this.bufferSync(b);\n for (let i = 0; i < resVals.length; ++i) {\n const loc = result.indexToLoc(i);\n\n const aLoc = loc.slice(-a.rank);\n aBroadcastDims.forEach(d => aLoc[d] = 0);\n const aIndex = aBuf.locToIndex(aLoc);\n\n const bLoc = loc.slice(-b.rank);\n bBroadcastDims.forEach(d => bLoc[d] = 0);\n const bIndex = bBuf.locToIndex(bLoc);\n\n resVals[i] = op(aVals[aIndex], bVals[bIndex]);\n }\n }\n return result.toTensor();\n }\n\n private broadcastedBinaryComplexOp(\n a: Tensor, b: Tensor,\n op:\n (aReal: number, aImag: number, bReal: number,\n bImag: number) => {real: number, imag: number}): Tensor {\n const newShape =\n broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);\n const realResult = ops.buffer(newShape, 'float32');\n const imagResult = ops.buffer(newShape, 'float32');\n\n const aVals = this.readSync(a.dataId) as TypedArray;\n const bVals = this.readSync(b.dataId) as TypedArray;\n const aBroadcastDims = broadcast_util.getBroadcastDims(a.shape, newShape);\n const bBroadcastDims = broadcast_util.getBroadcastDims(b.shape, newShape);\n\n const realVals = realResult.values;\n const imagVals = imagResult.values;\n\n if (aBroadcastDims.length + bBroadcastDims.length === 0) {\n for (let i = 0; i < realVals.length; i++) {\n const aIdx = i % aVals.length;\n const bIdx = i % bVals.length;\n\n const result =\n op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2],\n bVals[bIdx * 2 + 1]);\n\n realVals[i] = result.real;\n imagVals[i] = result.imag;\n }\n } else {\n const aRealBuf =\n this.bufferSync(this.data.get(a.dataId).complexTensors.real);\n const bRealBuf =\n this.bufferSync(this.data.get(b.dataId).complexTensors.real);\n for (let i = 0; i < realVals.length; i++) {\n const loc = realResult.indexToLoc(i);\n\n const aLoc = loc.slice(-a.rank);\n aBroadcastDims.forEach(d => aLoc[d] = 0);\n const aIndex = aRealBuf.locToIndex(aLoc);\n\n const bLoc = loc.slice(-b.rank);\n bBroadcastDims.forEach(d => bLoc[d] = 0);\n const bIndex = bRealBuf.locToIndex(bLoc);\n\n const opResult =\n op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2],\n bVals[bIndex * 2 + 1]);\n\n realVals[i] = opResult.real;\n imagVals[i] = opResult.imag;\n }\n }\n return this.complex(realResult.toTensor(), imagResult.toTensor());\n }\n\n split(x: T, sizeSplits: number[], axis: number): T[] {\n return split(x, sizeSplits, axis);\n }\n\n dispose() {}\n\n floatPrecision(): 16|32 {\n return 32;\n }\n /** Returns the smallest representable number. */\n epsilon(): number {\n return EPSILON_FLOAT32;\n }\n\n cropAndResize(\n images: Tensor4D,\n boxes: Tensor2D,\n boxIndex: Tensor1D,\n cropSize: [number, number],\n method: string,\n extrapolationValue: number,\n ) {\n const [batch, imageHeight, imageWidth, numChannels] = images.shape;\n const numBoxes = boxes.shape[0];\n\n const [cropHeight, cropWidth] = cropSize;\n const output =\n ops.buffer([numBoxes, cropHeight, cropWidth, numChannels], 'float32');\n\n const boxVals = this.readSync(boxes.dataId) as TypedArray;\n const boxIndVals = this.readSync(boxIndex.dataId) as TypedArray;\n const imageVals = this.readSync(images.dataId) as TypedArray;\n\n const inStride = images.strides; // to calculate flat indexes into image\n const outStride = output.strides; // to calculate flat indexes into output\n\n // Reference implementation\n // tslint:disable-next-line:max-line-length\n // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op.cc\n for (let b = 0; b < numBoxes; b++) {\n const startInd = b * 4;\n const y1 = boxVals[startInd];\n const x1 = boxVals[startInd + 1];\n const y2 = boxVals[startInd + 2];\n const x2 = boxVals[startInd + 3];\n\n const bInd: number = boxIndVals[b];\n if (bInd >= batch) {\n continue;\n }\n\n const heightScale = (cropHeight > 1) ?\n (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) :\n 0;\n const widthScale =\n (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0;\n\n for (let y = 0; y < cropHeight; y++) {\n const yInd: number = (cropHeight > 1) ?\n y1 * (imageHeight - 1) + y * (heightScale) :\n 0.5 * (y1 + y2) * (imageHeight - 1);\n\n if (yInd < 0 || yInd > imageHeight - 1) {\n for (let x = 0; x < cropWidth; x++) {\n for (let c = 0; c < numChannels; c++) {\n const ind =\n c + x * outStride[2] + y * outStride[1] + b * outStride[0];\n output.values[ind] = extrapolationValue;\n }\n }\n continue;\n }\n\n if (method === 'bilinear') {\n const topInd = Math.floor(yInd);\n const bottomInd = Math.ceil(yInd);\n const yLerp = yInd - topInd;\n\n for (let x = 0; x < cropWidth; x++) {\n const xInd = (cropWidth > 1) ?\n x1 * (imageWidth - 1) + x * widthScale :\n 0.5 * (x1 + x2) * (imageWidth - 1);\n\n if (xInd < 0 || xInd > imageWidth - 1) {\n for (let c = 0; c < numChannels; c++) {\n const ind =\n c + x * outStride[2] + y * outStride[1] + b * outStride[0];\n output.values[ind] = extrapolationValue;\n }\n continue;\n }\n\n const leftInd = Math.floor(xInd);\n const rightInd = Math.ceil(xInd);\n const xLerp = xInd - leftInd;\n\n for (let c = 0; c < numChannels; c++) {\n let ind = c + leftInd * inStride[2] + topInd * inStride[1] +\n bInd * inStride[0];\n const topLeft = imageVals[ind];\n\n ind = c + rightInd * inStride[2] + topInd * inStride[1] +\n bInd * inStride[0];\n const topRight = imageVals[ind];\n\n ind = c + leftInd * inStride[2] + bottomInd * inStride[1] +\n bInd * inStride[0];\n const bottomLeft = imageVals[ind];\n\n ind = c + rightInd * inStride[2] + bottomInd * inStride[1] +\n bInd * inStride[0];\n const bottomRight = imageVals[ind];\n\n const top = topLeft + (topRight - topLeft) * xLerp;\n const bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp;\n\n ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];\n output.values[ind] = top + ((bottom - top) * yLerp);\n }\n }\n } else { // method == \"nearest\"\n for (let x = 0; x < cropWidth; ++x) {\n const xInd = (cropWidth > 1) ?\n x1 * (imageWidth - 1) + x * widthScale :\n 0.5 * (x1 + x2) * (imageWidth - 1);\n\n if (xInd < 0 || xInd > imageWidth - 1) {\n for (let c = 0; c < numChannels; c++) {\n const ind =\n c + x * outStride[2] + y * outStride[1] + b * outStride[0];\n output.values[ind] = extrapolationValue;\n }\n continue;\n }\n\n const closestX = Math.round(xInd);\n const closestY = Math.round(yInd);\n for (let c = 0; c < numChannels; c++) {\n const inInd = c + closestX * inStride[2] +\n closestY * inStride[1] + bInd * inStride[0];\n const outInd =\n c + x * outStride[2] + y * outStride[1] + b * outStride[0];\n output.values[outInd] = imageVals[inInd];\n }\n }\n }\n }\n }\n return output.toTensor() as Tensor4D;\n }\n\n sparseToDense(\n sparseIndices: Tensor, sparseValues: Tensor, outputShape: ShapeMap[R],\n defaultValue: Scalar): Tensor {\n const {sliceRank, numUpdates, sliceSize, strides, outputSize} =\n scatter_nd_util.calculateShapes(\n sparseValues, sparseIndices, outputShape);\n const sumDupeIndices = false;\n return this.scatter(\n sparseIndices, sparseValues, outputShape, outputSize, sliceSize,\n numUpdates, sliceRank, strides, defaultValue, sumDupeIndices);\n }\n\n gatherND(x: Tensor, indices: Tensor): Tensor {\n const indicesShape = indices.shape;\n const sliceRank = indicesShape[indicesShape.length - 1];\n\n const [resultShape, numSlices, sliceSize, strides] =\n gather_nd_util.prepareAndValidate(x, indices);\n if (numSlices === 0) {\n return tensor([], resultShape, x.dtype);\n }\n\n const buffer = new TensorBuffer([numSlices, sliceSize], x.dtype);\n const indicesData = this.readSync(indices.dataId) as TypedArray;\n const xData = this.readSync(x.dataId) as TypedArray;\n\n for (let i = 0; i < numSlices; i++) {\n const index = [];\n let flattenIndex = 0;\n for (let j = 0; j < sliceRank; j++) {\n const dim = indicesData[i * sliceRank + j];\n flattenIndex += dim * strides[j];\n index.push(dim);\n }\n if (flattenIndex < 0 || flattenIndex >= x.size / sliceSize) {\n throw new Error(\n `Invalid indices: ${index} does not index into ${x.shape}`);\n }\n\n for (let k = 0; k < sliceSize; k++) {\n buffer.values[i * sliceSize + k] = xData[flattenIndex * sliceSize + k];\n }\n }\n return buffer.toTensor().reshape(resultShape);\n }\n\n scatterND(\n indices: Tensor, updates: Tensor, shape: ShapeMap[R]): Tensor {\n const {sliceRank, numUpdates, sliceSize, strides, outputSize} =\n scatter_nd_util.calculateShapes(updates, indices, shape);\n const defaultValue = scalar(0);\n const sumDupeIndices = true;\n return this.scatter(\n indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank,\n strides, defaultValue, sumDupeIndices);\n }\n\n fill(\n shape: ShapeMap[R], value: number|string, dtype?: DataType): Tensor {\n dtype = dtype || inferDtype(value);\n const values = getArrayFromDType(dtype, sizeFromShape(shape)) as TypedArray;\n values.fill(value as number);\n return ENGINE.makeTensor(values, shape, dtype, this) as Tensor;\n }\n\n onesLike(x: Tensor): Tensor {\n if (x.dtype === 'string') {\n throw new Error('onesLike is not supported for string tensors');\n } else {\n return this.fill(x.shape, 1, x.dtype);\n }\n }\n\n zerosLike(x: Tensor): Tensor {\n const values =\n getArrayFromDType(x.dtype, sizeFromShape(x.shape)) as TypedArray;\n return this.makeOutput(values, x.shape, x.dtype);\n }\n\n linspace(start: number, stop: number, num: number): Tensor1D {\n return backend_util.linspaceImpl(start, stop, num);\n }\n\n private scatter(\n indices: Tensor, updates: Tensor, shape: ShapeMap[R], outputSize: number,\n sliceSize: number, numUpdates: number, sliceRank: number,\n strides: number[], defaultValue: Scalar,\n sumDupeIndices: boolean): Tensor {\n const flattenShape = [outputSize / sliceSize, sliceSize];\n\n const indicesData = this.readSync(indices.dataId) as TypedArray;\n const updatesData = this.readSync(updates.dataId) as TypedArray;\n\n if (outputSize === 0) {\n return tensor([], shape, updates.dtype);\n }\n\n const buffer = new TensorBuffer(flattenShape, updates.dtype as 'float32');\n buffer.values.fill((this.readSync(defaultValue.dataId) as TypedArray)[0]);\n\n for (let i = 0; i < numUpdates; i++) {\n const index = [];\n let flattenIndex = 0;\n for (let j = 0; j < sliceRank; j++) {\n const dim = indicesData[i * sliceRank + j];\n index.push(dim);\n flattenIndex += dim * strides[j];\n }\n\n if (flattenIndex < 0 || flattenIndex >= outputSize / sliceSize) {\n throw new Error(\n `Invalid indices: ${index} does not index into ${shape}`);\n }\n\n for (let k = 0; k < sliceSize; k++) {\n if (sumDupeIndices) {\n buffer.values[flattenIndex * sliceSize + k] +=\n updatesData[i * sliceSize + k];\n } else {\n buffer.values[flattenIndex * sliceSize + k] = updates.rank === 0 ?\n updatesData[0] :\n updatesData[i * sliceSize + k];\n }\n }\n }\n return buffer.toTensor().reshape(shape);\n }\n}\n\nENGINE.registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */);\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as backend_util from '../../../backends/backend_util';\nimport {BinaryInputs} from '../../../kernel_names';\nimport {KernelConfig} from '../../../kernel_registry';\nimport {DataType, NumericDataType, TypedArray} from '../../../types';\nimport * as util from '../../../util';\nimport {MathBackendCPU} from '../backend_cpu';\nimport {assertNotComplex} from '../cpu_util';\n\nexport function createBinaryKernelConfig(\n name: string,\n op: (\n aShape: number[], bShape: number[], aVals: TypedArray,\n bVals: TypedArray,\n dtype: DataType) => [TypedArray, number[]]): KernelConfig {\n return {\n kernelName: name,\n backendName: 'cpu',\n kernelFunc: ({inputs, backend}) => {\n const {a, b} = inputs as BinaryInputs;\n const cpuBackend = backend as MathBackendCPU;\n assertNotComplex([a, b], name);\n\n const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;\n const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;\n\n const [resultData, resultShape] =\n op(a.shape, b.shape, aVals, bVals, a.dtype);\n\n const dataId = cpuBackend.write(resultData, resultShape, a.dtype);\n return {dataId, shape: resultShape, dtype: a.dtype};\n }\n };\n}\n\nexport function createBinaryKernelImpl(op: (a: number, b: number) => number) {\n return (aShape: number[], bShape: number[], aVals: TypedArray,\n bVals: TypedArray, dtype: DataType): [TypedArray, number[]] => {\n const newShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);\n\n const resultRank = newShape.length;\n const resultStrides = util.computeStrides(newShape);\n const resultSize = util.sizeFromShape(newShape);\n\n const result =\n util.getTypedArrayFromDType(dtype as NumericDataType, resultSize);\n\n const aRank = aShape.length;\n const bRank = bShape.length;\n\n const aStrides = util.computeStrides(aShape);\n const bStrides = util.computeStrides(bShape);\n\n const aBroadcastDims = backend_util.getBroadcastDims(aShape, newShape);\n const bBroadcastDims = backend_util.getBroadcastDims(bShape, newShape);\n\n if (aBroadcastDims.length + bBroadcastDims.length === 0) {\n for (let i = 0; i < result.length; ++i) {\n result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);\n }\n } else {\n for (let i = 0; i < result.length; ++i) {\n const loc = util.indexToLoc(i, resultRank, resultStrides);\n\n const aLoc = loc.slice(-aRank);\n aBroadcastDims.forEach(d => aLoc[d] = 0);\n const aIndex = util.locToIndex(aLoc, aRank, aStrides);\n\n const bLoc = loc.slice(-bRank);\n bBroadcastDims.forEach(d => bLoc[d] = 0);\n const bIndex = util.locToIndex(bLoc, bRank, bStrides);\n\n result[i] = op(aVals[aIndex], bVals[bIndex]);\n }\n }\n\n return [result, newShape];\n };\n}\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {createBinaryKernelImpl} from '../utils/kernel_utils';\n\nexport const divImpl = createBinaryKernelImpl((a: number, b: number) => a / b);\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Div} from '../../../kernel_names';\nimport {createBinaryKernelConfig} from '../utils/kernel_utils';\nimport {divImpl} from './Div_impl';\n\nexport const divConfig = createBinaryKernelConfig(Div, divImpl);\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {MaxPoolWithArgmax, MaxPoolWithArgmaxAttrs, MaxPoolWithArgmaxInputs} from '../../../kernel_names';\nimport {KernelConfig} from '../../../kernel_registry';\nimport * as conv_util from '../../../ops/conv_util';\nimport {TypedArray} from '../../../types';\nimport {MathBackendCPU} from '../backend_cpu';\nimport {assertNotComplex} from '../cpu_util';\n\nimport {maxPoolWithArgmaxImpl} from './MaxPoolWithArgmax_impl';\n\nexport const maxPoolWithArgmaxConfig: KernelConfig = {\n kernelName: MaxPoolWithArgmax,\n backendName: 'cpu',\n kernelFunc: ({inputs, attrs, backend}) => {\n const {x} = inputs as MaxPoolWithArgmaxInputs;\n const {filterSize, strides, pad, includeBatchInIndex} =\n attrs as {} as MaxPoolWithArgmaxAttrs;\n const cpuBackend = backend as MathBackendCPU;\n assertNotComplex(x, 'MaxPoolWithArgmax');\n\n const values = cpuBackend.data.get(x.dataId).values as TypedArray;\n const convInfo = conv_util.computePool2DInfo(\n x.shape as [number, number, number, number], filterSize, strides,\n [1, 1], pad);\n const [pooled, indexes] = maxPoolWithArgmaxImpl(\n values, x.shape, x.dtype, includeBatchInIndex, convInfo);\n\n const pooledDataId =\n cpuBackend.write(pooled as Float32Array, convInfo.outShape, x.dtype);\n const indexesDataId =\n cpuBackend.write(indexes as Int32Array, convInfo.outShape, x.dtype);\n return [\n {dataId: pooledDataId, shape: convInfo.outShape, dtype: x.dtype},\n {dataId: indexesDataId, shape: convInfo.outShape, dtype: 'int32'}\n ];\n }\n};\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {Conv2DInfo} from '../../../ops/conv_util';\nimport {DataType, TypedArray} from '../../../types';\nimport {computeStrides} from '../../../util';\nimport {maxPoolPositions, pool} from '../pool_utils';\nexport function maxPoolWithArgmaxImpl(\n xValues: TypedArray, xShape: number[], dtype: DataType,\n includeBatchInIndex: boolean, convInfo: Conv2DInfo) {\n const strides = computeStrides(xShape);\n const maxPools = pool(xValues, xShape, dtype, strides, convInfo, 'max');\n const maxPositions = maxPoolPositions(\n xValues, xShape, dtype, convInfo, true, includeBatchInIndex);\n\n return [maxPools.values, maxPositions.values];\n}\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {NonMaxSuppressionV5, NonMaxSuppressionV5Attrs, NonMaxSuppressionV5Inputs} from '../../../kernel_names';\nimport {KernelConfig} from '../../../kernel_registry';\nimport {TypedArray} from '../../../types';\nimport {nonMaxSuppressionV5} from '../../non_max_suppression_impl';\n\nimport {MathBackendCPU} from '../backend_cpu';\nimport {assertNotComplex} from '../cpu_util';\n\nexport const nonMaxSuppressionV5Config: KernelConfig = {\n kernelName: NonMaxSuppressionV5,\n backendName: 'cpu',\n kernelFunc: ({inputs, backend, attrs}) => {\n const {boxes, scores} = inputs as NonMaxSuppressionV5Inputs;\n const {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma} =\n attrs as unknown as NonMaxSuppressionV5Attrs;\n\n const cpuBackend = backend as MathBackendCPU;\n\n assertNotComplex(boxes, 'NonMaxSuppressionWithScore');\n\n const boxesVals = cpuBackend.data.get(boxes.dataId).values as TypedArray;\n const scoresVals = cpuBackend.data.get(scores.dataId).values as TypedArray;\n\n const maxOutputSizeVal = maxOutputSize;\n const iouThresholdVal = iouThreshold;\n const scoreThresholdVal = scoreThreshold;\n const softNmsSigmaVal = softNmsSigma;\n\n const {selectedIndices, selectedScores} = nonMaxSuppressionV5(\n boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal,\n scoreThresholdVal, softNmsSigmaVal);\n\n return [selectedIndices, selectedScores];\n }\n};\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Square, SquareInputs} from '../../../kernel_names';\nimport {KernelConfig} from '../../../kernel_registry';\nimport {MathBackendCPU} from '../backend_cpu';\nimport {assertNotComplex} from '../cpu_util';\n\nexport const squareConfig: KernelConfig = {\n kernelName: Square,\n backendName: 'cpu',\n kernelFunc: ({inputs, backend}) => {\n const {x} = inputs as SquareInputs;\n const cpuBackend = backend as MathBackendCPU;\n assertNotComplex(x, 'square');\n\n const values = cpuBackend.data.get(x.dataId).values as Float32Array;\n const newValues = new Float32Array(values.length);\n for (let i = 0; i < values.length; ++i) {\n const value = values[i];\n newValues[i] = value * value;\n }\n const dataId = cpuBackend.write(newValues, x.shape, x.dtype);\n return {dataId, shape: x.shape, dtype: x.dtype};\n }\n};\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {SquaredDifference} from '../../../kernel_names';\nimport {createBinaryKernelImpl} from '../utils/kernel_utils';\nimport {createBinaryKernelConfig} from '../utils/kernel_utils';\n\nconst squaredDifferenceImpl = createBinaryKernelImpl((aVal, bVal) => {\n const diff = aVal - bVal;\n return diff * diff;\n});\n\nexport const squaredDifferenceConfig =\n createBinaryKernelConfig(SquaredDifference, squaredDifferenceImpl);\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {DataType, NumericDataType, TypedArray} from '../../../types';\nimport * as util from '../../../util';\n\nexport function transposeImpl(\n xVals: TypedArray, xShape: number[], dtype: DataType, perm: number[],\n newShape: number[]): TypedArray {\n const xSize = util.sizeFromShape(xShape);\n const xRank = xShape.length;\n const xStrides = util.computeStrides(xShape);\n const newStrides = util.computeStrides(newShape);\n\n const result = util.getTypedArrayFromDType(\n dtype as NumericDataType, util.sizeFromShape(newShape));\n\n for (let i = 0; i < xSize; ++i) {\n const loc = util.indexToLoc(i, xRank, xStrides);\n\n // Permute location.\n const newLoc: number[] = new Array(loc.length);\n for (let i = 0; i < newLoc.length; i++) {\n newLoc[i] = loc[perm[i]];\n }\n\n const newIndex = util.locToIndex(newLoc, xRank, newStrides);\n result[newIndex] = xVals[i];\n }\n return result;\n}\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n// We explicitly import the modular kernels so they get registered in the\n// global registry when we compile the library. A modular build would replace\n// the contents of this file and import only the kernels that are needed.\nimport {KernelConfig, registerKernel} from '../../kernel_registry';\n\nimport {divConfig} from './kernels/Div';\nimport {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax';\nimport {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';\nimport {squareConfig} from './kernels/Square';\nimport {squaredDifferenceConfig} from './kernels/SquaredDifference';\nimport {transposeConfig} from './kernels/Transpose';\n\n// List all kernel configs here\nconst kernelConfigs: KernelConfig[] = [\n nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig,\n transposeConfig, maxPoolWithArgmaxConfig\n];\n\nfor (const kernelConfig of kernelConfigs) {\n registerKernel(kernelConfig);\n}\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Transpose, TransposeAttrs, TransposeInputs} from '../../../kernel_names';\nimport {KernelConfig} from '../../../kernel_registry';\nimport {TypedArray} from '../../../types';\nimport {MathBackendCPU} from '../backend_cpu';\nimport {assertNotComplex} from '../cpu_util';\n\nimport {transposeImpl} from './Transpose_impl';\n\nexport const transposeConfig: KernelConfig = {\n kernelName: Transpose,\n backendName: 'cpu',\n kernelFunc: ({inputs, attrs, backend}) => {\n const {x} = inputs as TransposeInputs;\n const {perm} = attrs as {} as TransposeAttrs;\n const cpuBackend = backend as MathBackendCPU;\n\n assertNotComplex(x, 'transpose');\n\n const xRank = x.shape.length;\n\n const newShape: number[] = new Array(xRank);\n for (let i = 0; i < newShape.length; i++) {\n newShape[i] = x.shape[perm[i]];\n }\n\n const values = cpuBackend.data.get(x.dataId).values as TypedArray;\n const result = transposeImpl(values, x.shape, x.dtype, perm, newShape);\n\n const dataId = cpuBackend.write(result, newShape, x.dtype);\n return {dataId, shape: newShape, dtype: x.dtype};\n }\n};\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Div, DivInputs} from '../../../kernel_names';\nimport {KernelConfig} from '../../../kernel_registry';\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {divImpl} from './Div_impl';\n\nexport const divConfig: KernelConfig = {\n kernelName: Div,\n backendName: 'webgl',\n kernelFunc: ({inputs, backend}) => {\n const {a, b} = inputs as DivInputs;\n\n const webglBackend = backend as MathBackendWebGL;\n\n return divImpl(a, b, webglBackend);\n }\n};\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getGlslDifferences} from '../../glsl_version';\nimport {GPGPUProgram} from '../../gpgpu_math';\n\nexport class FromPixelsProgram implements GPGPUProgram {\n variableNames = ['A'];\n userCode: string;\n outputShape: number[];\n\n constructor(outputShape: number[]) {\n const glsl = getGlslDifferences();\n const [height, width, ] = outputShape;\n this.outputShape = outputShape;\n this.userCode = `\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0);\n\n vec4 values = ${glsl.texture2D}(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n setOutput(floor(value * 255.0 + 0.5));\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2018 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getGlslDifferences} from '../../glsl_version';\nimport {GPGPUProgram} from '../../gpgpu_math';\n\nexport class FromPixelsPackedProgram implements GPGPUProgram {\n variableNames = ['A'];\n userCode: string;\n outputShape: number[];\n packedInputs = false;\n packedOutput = true;\n\n constructor(outputShape: number[]) {\n const glsl = getGlslDifferences();\n const [height, width, ] = outputShape;\n this.outputShape = outputShape;\n this.userCode = `\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n\n vec4 result = vec4(0.);\n\n for(int row=0; row<=1; row++) {\n for(int col=0; col<=1; col++) {\n texC = coords[1] + row;\n depth = coords[2] + col;\n\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(${width}.0, ${height}.0);\n vec4 values = ${glsl.texture2D}(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n result[row * 2 + col] = floor(value * 255.0 + 0.5);\n }\n }\n\n ${glsl.output} = result;\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {MaxPoolWithArgmax, MaxPoolWithArgmaxAttrs, MaxPoolWithArgmaxInputs} from '../../../kernel_names';\nimport {KernelConfig} from '../../../kernel_registry';\nimport * as conv_util from '../../../ops/conv_util';\nimport * as util from '../../../util';\nimport {MathBackendWebGL} from '../backend_webgl';\n\nimport {maxPoolWithArgmaxImpl} from './MaxPoolWithArgmax_impl';\n\nexport const maxPoolWithArgmaxConfig: KernelConfig = {\n kernelName: MaxPoolWithArgmax,\n backendName: 'webgl',\n kernelFunc: ({inputs, attrs, backend}) => {\n const {x} = inputs as MaxPoolWithArgmaxInputs;\n const {filterSize, strides, pad, includeBatchInIndex} =\n attrs as {} as MaxPoolWithArgmaxAttrs;\n const webglBackend = backend as MathBackendWebGL;\n\n util.assert(\n x.shape.length === 4,\n () => `Error in maxPool: input must be rank 4 but got rank ${\n x.shape.length}.`);\n const dilations: [number, number] = [1, 1];\n util.assert(\n conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n () => 'Error in maxPool: Either strides or dilations must be 1. ' +\n `Got strides ${strides} and dilations '${dilations}'`);\n\n const convInfo = conv_util.computePool2DInfo(\n x.shape as [number, number, number, number], filterSize, strides,\n dilations, pad);\n\n const [result, indexes] =\n maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, webglBackend);\n return [result, indexes];\n }\n};\n","/**\n * @license\n * Copyright 2017 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class TransposeProgram implements GPGPUProgram {\n variableNames = ['A'];\n outputShape: number[];\n userCode: string;\n rank: number;\n\n constructor(aShape: number[], newDim: number[]) {\n const outputShape: number[] = new Array(aShape.length);\n for (let i = 0; i < outputShape.length; i++) {\n outputShape[i] = aShape[newDim[i]];\n }\n this.outputShape = outputShape;\n this.rank = outputShape.length;\n const dtype = getCoordsDataType(this.rank);\n const switched = getSwitchedCoords(newDim);\n\n this.userCode = `\n void main() {\n ${dtype} resRC = getOutputCoords();\n setOutput(getA(${switched}));\n }\n `;\n }\n}\n\nfunction getSwitchedCoords(newDim: number[]): string {\n const rank = newDim.length;\n if (rank > 6) {\n throw Error(`Transpose for rank ${rank} is not yet supported`);\n }\n const originalOrder =\n ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'];\n const switchedCoords = new Array(rank);\n for (let i = 0; i < newDim.length; i++) {\n switchedCoords[newDim[i]] = originalOrder[i];\n }\n return switchedCoords.join();\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {getVecChannels} from '../packing_util';\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class TransposePackedProgram implements GPGPUProgram {\n variableNames = ['A'];\n outputShape: number[];\n userCode: string;\n rank: number;\n packedInputs = true;\n packedOutput = true;\n\n constructor(aShape: number[], newDim: number[]) {\n const outputShape: number[] = new Array(aShape.length);\n for (let i = 0; i < outputShape.length; i++) {\n outputShape[i] = aShape[newDim[i]];\n }\n this.outputShape = outputShape;\n this.rank = outputShape.length;\n if (this.rank > 6) {\n throw Error(\n `Packed transpose for rank ${this.rank} is not yet supported.`);\n }\n const dtype = getCoordsDataType(this.rank);\n\n const outputOrder = getVecChannels('rc', this.rank);\n const switchedOrder = new Array(this.rank);\n for (let i = 0; i < newDim.length; i++) {\n switchedOrder[newDim[i]] = outputOrder[i];\n }\n const innerDims = `vec2(${switchedOrder.slice(-2).join()})`;\n const nextColumn =\n `++${outputOrder[this.rank - 1]} < ${outputShape[this.rank - 1]}`;\n const getc = `getChannel(getA(${switchedOrder.join()}), ${innerDims})`;\n\n this.userCode = `\n void main() {\n ${dtype} rc = getOutputCoords();\n vec4 result = vec4(0.);\n result[0] = ${getc};\n if(${nextColumn}) {\n result[1] = ${getc};\n }\n --${outputOrder[this.rank - 1]};\n if(++${outputOrder[this.rank - 2]} < ${outputShape[this.rank - 2]}) {\n result[2] = ${getc};\n if(${nextColumn}) {\n result[3] = ${getc};\n }\n }\n setOutput(result);\n }\n `;\n }\n}\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {KernelConfig, registerKernel} from '../../kernel_registry';\n\nimport {divConfig} from './kernels/Div';\nimport {fromPixelsConfig} from './kernels/FromPixels';\nimport {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax';\nimport {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';\nimport {squareConfig} from './kernels/Square';\nimport {squaredDifferenceConfig} from './kernels/SquaredDifference';\nimport {transposeConfig} from './kernels/Transpose';\n\n// List all kernel configs here\nconst kernelConfigs: KernelConfig[] = [\n fromPixelsConfig, divConfig, nonMaxSuppressionV5Config, squareConfig,\n squaredDifferenceConfig, transposeConfig, maxPoolWithArgmaxConfig\n];\n\nfor (const kernelConfig of kernelConfigs) {\n registerKernel(kernelConfig);\n}\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {TypedArray} from '../../../../src/types';\nimport {transposeImpl as cpuTranspose} from '../../../backends/cpu/kernels/Transpose_impl';\nimport {Transpose, TransposeAttrs, TransposeInputs} from '../../../kernel_names';\nimport {KernelConfig, TensorInfo} from '../../../kernel_registry';\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {transposeImpl} from './Transpose_impl';\n\nexport const transposeConfig: KernelConfig = {\n kernelName: Transpose,\n backendName: 'webgl',\n kernelFunc: ({inputs, attrs, backend}) => {\n const {x} = inputs as TransposeInputs;\n const {perm} = attrs as {} as TransposeAttrs;\n const webglBackend = backend as MathBackendWebGL;\n\n const xRank = x.shape.length;\n\n const newShape: number[] = new Array(xRank);\n for (let i = 0; i < newShape.length; i++) {\n newShape[i] = x.shape[perm[i]];\n }\n\n let out: TensorInfo;\n if (webglBackend.shouldExecuteOnCPU([x])) {\n const xTexData = webglBackend.texData.get(x.dataId);\n const values = xTexData.values as TypedArray;\n const outValues = cpuTranspose(values, x.shape, x.dtype, perm, newShape);\n\n out = webglBackend.makeTensorInfo(newShape, x.dtype);\n const outData = webglBackend.texData.get(out.dataId);\n outData.values = outValues;\n } else {\n out = transposeImpl(x, perm, webglBackend);\n }\n return out;\n }\n};\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../../../environment';\nimport {FromPixels, FromPixelsAttrs, FromPixelsInputs} from '../../../kernel_names';\nimport {KernelFunc, TensorInfo} from '../../../kernel_registry';\nimport {KernelConfig} from '../../../kernel_registry';\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {TextureUsage} from '../tex_util';\n\nimport {FromPixelsProgram} from './FromPixels_utils/from_pixels_gpu';\nimport {FromPixelsPackedProgram} from './FromPixels_utils/from_pixels_packed_gpu';\n\nexport const fromPixelsConfig: KernelConfig = {\n kernelName: FromPixels,\n backendName: 'webgl',\n kernelFunc: fromPixels as {} as KernelFunc,\n};\n\nlet fromPixels2DContext: CanvasRenderingContext2D;\n\nfunction fromPixels(args: {\n inputs: FromPixelsInputs,\n backend: MathBackendWebGL,\n attrs: FromPixelsAttrs\n}): TensorInfo {\n const {inputs, backend, attrs} = args;\n let {pixels} = inputs;\n const {numChannels} = attrs;\n\n const isVideo = typeof (HTMLVideoElement) !== 'undefined' &&\n pixels instanceof HTMLVideoElement;\n const isImage = typeof (HTMLImageElement) !== 'undefined' &&\n pixels instanceof HTMLImageElement;\n const [width, height] = isVideo ?\n [\n (pixels as HTMLVideoElement).videoWidth,\n (pixels as HTMLVideoElement).videoHeight\n ] :\n [pixels.width, pixels.height];\n\n const texShape: [number, number] = [height, width];\n const outShape = [height, width, numChannels];\n\n if (isImage || isVideo) {\n if (fromPixels2DContext == null) {\n fromPixels2DContext = document.createElement('canvas').getContext('2d');\n }\n\n fromPixels2DContext.canvas.width = width;\n fromPixels2DContext.canvas.height = height;\n fromPixels2DContext.drawImage(\n pixels as HTMLVideoElement | HTMLImageElement, 0, 0, width, height);\n pixels = fromPixels2DContext.canvas;\n }\n\n const tempPixelHandle = backend.makeTensorInfo(texShape, 'int32');\n // This is a byte texture with pixels.\n backend.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS;\n backend.gpgpu.uploadPixelDataToTexture(\n backend.getTexture(tempPixelHandle.dataId), pixels as ImageData);\n const program = env().getBool('WEBGL_PACK') ?\n new FromPixelsPackedProgram(outShape) :\n new FromPixelsProgram(outShape);\n const res = backend.runWebGLProgram(program, [tempPixelHandle], 'int32');\n backend.disposeData(tempPixelHandle.dataId);\n return res;\n}\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../../../environment';\nimport {TensorInfo} from '../../../kernel_registry';\nimport {MathBackendWebGL} from '../backend_webgl';\nimport * as binaryop_gpu from '../binaryop_gpu';\nimport {BinaryOpProgram} from '../binaryop_gpu';\nimport * as binaryop_packed_gpu from '../binaryop_packed_gpu';\nimport {BinaryOpPackedProgram} from '../binaryop_packed_gpu';\n\nexport function divImpl(\n a: TensorInfo, b: TensorInfo, backend: MathBackendWebGL): TensorInfo {\n let program = new BinaryOpProgram(binaryop_gpu.DIV, a.shape, b.shape);\n if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n program = new BinaryOpPackedProgram(\n binaryop_packed_gpu.DIV, a.shape, b.shape, true);\n }\n const output = backend.runWebGLProgram(program, [a, b], 'float32');\n return output;\n}\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {NonMaxSuppressionV5, NonMaxSuppressionV5Attrs, NonMaxSuppressionV5Inputs} from '../../../kernel_names';\nimport {KernelConfig} from '../../../kernel_registry';\nimport {warn} from '../../../log';\nimport {TypedArray} from '../../../types';\nimport {nonMaxSuppressionV5} from '../../non_max_suppression_impl';\n\nimport {MathBackendWebGL} from '../backend_webgl';\n\nexport const nonMaxSuppressionV5Config: KernelConfig = {\n kernelName: NonMaxSuppressionV5,\n backendName: 'webgl',\n kernelFunc: ({inputs, backend, attrs}) => {\n warn(\n 'tf.nonMaxSuppression() in webgl locks the UI thread. ' +\n 'Call tf.nonMaxSuppressionAsync() instead');\n\n const {boxes, scores} = inputs as NonMaxSuppressionV5Inputs;\n const {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma} =\n attrs as unknown as NonMaxSuppressionV5Attrs;\n\n const gpuBackend = backend as MathBackendWebGL;\n\n const boxesVals = gpuBackend.readSync(boxes.dataId) as TypedArray;\n const scoresVals = gpuBackend.readSync(scores.dataId) as TypedArray;\n\n const maxOutputSizeVal = maxOutputSize;\n const iouThresholdVal = iouThreshold;\n const scoreThresholdVal = scoreThreshold;\n const softNmsSigmaVal = softNmsSigma;\n\n const {selectedIndices, selectedScores} = nonMaxSuppressionV5(\n boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal,\n scoreThresholdVal, softNmsSigmaVal);\n\n return [selectedIndices, selectedScores];\n }\n};\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Square, SquareInputs} from '../../../kernel_names';\nimport {KernelConfig} from '../../../kernel_registry';\n\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {SQUARE, UnaryOpProgram} from '../unaryop_gpu';\n\nexport const squareConfig: KernelConfig = {\n kernelName: Square,\n backendName: 'webgl',\n kernelFunc: ({inputs, backend}) => {\n const {x} = inputs as SquareInputs;\n const webglBackend = backend as MathBackendWebGL;\n const program = new UnaryOpProgram(x.shape, SQUARE);\n return webglBackend.runWebGLProgram(program, [x], x.dtype);\n }\n};\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../../../environment';\nimport {SquaredDifference, SquaredDifferenceInputs} from '../../../kernel_names';\nimport {KernelConfig} from '../../../kernel_registry';\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {BinaryOpProgram} from '../binaryop_gpu';\nimport {BinaryOpPackedProgram} from '../binaryop_packed_gpu';\n\nexport const squaredDifferenceConfig: KernelConfig = {\n kernelName: SquaredDifference,\n backendName: 'webgl',\n kernelFunc: ({inputs, backend}) => {\n const {a, b} = inputs as SquaredDifferenceInputs;\n const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';\n const webGLBackend = backend as MathBackendWebGL;\n\n const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?\n new BinaryOpPackedProgram(SQUARED_DIFFERENCE, a.shape, b.shape) :\n new BinaryOpProgram(SQUARED_DIFFERENCE, a.shape, b.shape);\n return webGLBackend.compileAndRun(program, [a, b]);\n }\n};\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../../../environment';\nimport {TensorInfo} from '../../../kernel_registry';\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {TransposeProgram} from '../transpose_gpu';\nimport {TransposePackedProgram} from '../transpose_packed_gpu';\n\nexport function transposeImpl(\n x: TensorInfo, perm: number[], backend: MathBackendWebGL): TensorInfo {\n const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?\n new TransposePackedProgram(x.shape, perm) :\n new TransposeProgram(x.shape, perm);\n return backend.runWebGLProgram(program, [x], x.dtype);\n}\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {TensorInfo} from '../../../kernel_registry';\nimport * as conv_util from '../../../ops/conv_util';\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {Pool2DProgram} from '../pool_gpu';\n\nexport function maxPoolWithArgmaxImpl(\n x: TensorInfo, includeBatchInIndex: boolean, convInfo: conv_util.Conv2DInfo,\n backend: MathBackendWebGL): TensorInfo[] {\n let program = new Pool2DProgram(convInfo, 'max', false);\n const poolOutput = backend.runWebGLProgram(program, [x], 'float32');\n\n program = new Pool2DProgram(convInfo, 'max', true, true, includeBatchInIndex);\n const indexOutput = backend.runWebGLProgram(program, [x], 'float32');\n return [poolOutput, indexOutput];\n}\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {addGradConfig} from './gradients/Add_grad';\nimport {addNGradConfig} from './gradients/AddN_grad';\nimport {broadcastToGradConfig} from './gradients/BroadcastTo_grad';\nimport {divGradConfig} from './gradients/Div_grad';\nimport {fusedBatchNormGradConfig} from './gradients/FusedBatchNorm_grad';\nimport {identityGradConfig} from './gradients/Identity_grad';\nimport {oneHotGradConfig} from './gradients/OneHot_grad';\nimport {padV2GradConfig} from './gradients/PadV2_grad';\nimport {squareGradConfig} from './gradients/Square_grad';\nimport {squaredDifferenceGradConfig} from './gradients/SquaredDifference_grad';\nimport {tileGradConfig} from './gradients/Tile_grad';\nimport {transposeGradConfig} from './gradients/Transpose_grad';\nimport {GradConfig} from './kernel_registry';\nimport {registerGradient} from './kernel_registry';\n\n// Export all kernel configs here so that the package can auto register them\nconst gradConfigs: GradConfig[] = [\n addGradConfig, addNGradConfig, broadcastToGradConfig, divGradConfig,\n fusedBatchNormGradConfig, identityGradConfig, oneHotGradConfig,\n padV2GradConfig, squareGradConfig, squaredDifferenceGradConfig,\n tileGradConfig, transposeGradConfig\n];\n\nfor (const gradientConfig of gradConfigs) {\n registerGradient(gradientConfig);\n}\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {Add} from '../kernel_names';\nimport {GradConfig} from '../kernel_registry';\nimport * as broadcast_util from '../ops/broadcast_util';\nimport {Tensor} from '../tensor';\n\nexport const addGradConfig: GradConfig = {\n kernelName: Add,\n inputsToSave: ['a', 'b'],\n gradFunc: (dy: Tensor, saved: Tensor[]) => {\n const [a, b] = saved;\n const outShape =\n broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);\n\n const derA = () => {\n let res = dy;\n const reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape);\n if (reduceAxes.length > 0) {\n res = res.sum(reduceAxes);\n }\n return res.reshape(a.shape);\n };\n const derB = () => {\n let res = dy;\n const reduceAxes = broadcast_util.getReductionAxes(b.shape, outShape);\n if (reduceAxes.length > 0) {\n res = res.sum(reduceAxes);\n }\n return res.reshape(b.shape);\n };\n\n return {a: derA, b: derB};\n }\n};\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {AddN} from '../kernel_names';\nimport {GradConfig} from '../kernel_registry';\nimport {Tensor} from '../tensor';\n\nexport const addNGradConfig: GradConfig = {\n kernelName: AddN,\n saveAllInputs: true,\n gradFunc: (dy: Tensor, saved: Tensor[]) => {\n const ders: {[key: string]: () => Tensor} = {};\n saved.forEach((_, i) => {\n ders[i] = () => dy.clone();\n });\n return ders;\n }\n};\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {BroadcastTo, BroadCastToAttrs} from '../kernel_names';\nimport {GradConfig, NamedAttrMap} from '../kernel_registry';\nimport {Tensor} from '../tensor';\n\nexport const broadcastToGradConfig: GradConfig = {\n kernelName: BroadcastTo,\n gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {\n const broadCastToAttrs: BroadCastToAttrs =\n attrs as unknown as BroadCastToAttrs;\n\n const inputShape = broadCastToAttrs.inputShape;\n const outputShape = broadCastToAttrs.shape;\n\n const reps: number[] = Array.from(outputShape);\n for (let i = inputShape.length - 1; i >= 0; i--) {\n if (inputShape[i] === outputShape[i]) {\n reps[i] = 1;\n } else if (inputShape[i] !== 1) {\n throw new Error(`broadcastTo(): [${\n inputShape}] cannot be broadcast to [${outputShape}].`);\n }\n }\n const axes: number[] = [];\n for (let i = 0; i < reps.length; i++) {\n if (reps[i] > 1) {\n axes.push(i);\n }\n }\n const keepDims = true;\n return {x: () => dy.sum(axes, keepDims)};\n }\n};\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Div} from '../kernel_names';\nimport {GradConfig} from '../kernel_registry';\nimport * as broadcast_util from '../ops/broadcast_util';\nimport {div} from '../ops/div';\nimport {sum} from '../ops/reduction_ops';\nimport {square} from '../ops/square';\nimport {neg} from '../ops/unary_ops';\nimport {Tensor} from '../tensor';\n\nexport const divGradConfig: GradConfig = {\n kernelName: Div,\n inputsToSave: ['a', 'b'],\n gradFunc: (dy: Tensor, saved: Tensor[]) => {\n const [a, b] = saved;\n const outShape =\n broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);\n const derA = () => {\n const res = div(dy, b.toFloat());\n const reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape);\n if (reduceAxes.length > 0) {\n return sum(res, reduceAxes).reshape(a.shape);\n }\n return res;\n };\n const derB = () => {\n let res = dy.mul(a.toFloat());\n const reduceAxes = broadcast_util.getReductionAxes(b.shape, outShape);\n if (reduceAxes.length > 0) {\n res = sum(res, reduceAxes).reshape(b.shape);\n }\n const tmp = square(b);\n return neg(div(res, tmp.toFloat()));\n };\n return {a: derA, b: derB};\n }\n};\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {FusedBatchNorm, FusedBatchNormAttrs} from '../kernel_names';\nimport {GradConfig, NamedAttrMap} from '../kernel_registry';\nimport {xAs4D} from '../ops/batchnorm_util';\nimport {getReductionAxes} from '../ops/broadcast_util';\nimport {add, mul, reshape, sub} from '../ops/ops';\nimport {sum} from '../ops/reduction_ops';\nimport {scalar} from '../ops/tensor_ops';\nimport {tile} from '../ops/tile';\nimport {rsqrt} from '../ops/unary_ops';\nimport {Tensor, Tensor4D} from '../tensor';\nimport {Rank, ShapeMap} from '../types';\n\nexport const fusedBatchNormGradConfig: GradConfig = {\n kernelName: FusedBatchNorm,\n inputsToSave: ['x', 'mean', 'variance', 'scale'],\n gradFunc: (\n dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {\n const batchNormalizationAttrs: FusedBatchNormAttrs =\n attrs as {} as FusedBatchNormAttrs;\n const {varianceEpsilon} = batchNormalizationAttrs;\n const [x, mean, variance, scale] = saved;\n\n const x4D: Tensor4D = xAs4D(x);\n\n const scaleValue = scale == null ? scalar(1) : scale;\n const reductionAxes = getReductionAxes(mean.shape, x4D.shape);\n const tileShape: number[] = [];\n if (mean.rank === 1) {\n for (let i = 0; i < x4D.shape.length - 1; ++i) {\n tileShape.push(x4D.shape[i]);\n }\n tileShape.push(1);\n }\n\n const xMinusMean = sub(x, mean);\n const dyTimesScaleValue = mul(dy, scaleValue);\n const oneOverSqrtVariance = rsqrt(add(variance, scalar(varianceEpsilon)));\n const minusHalfRCube = mul(\n mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance),\n scalar(-0.5));\n\n const derX = () => {\n if (mean.rank === 1) {\n return reshape(\n mul(mul(dy,\n tile(\n oneOverSqrtVariance.as4D(1, 1, 1, mean.shape[0]),\n tileShape)),\n scaleValue),\n x.shape);\n } else {\n return reshape(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape);\n }\n };\n const derMean = () => {\n let meanDer =\n mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue);\n if (mean.rank === 1) {\n meanDer = sum(meanDer, reductionAxes);\n }\n return reshape(meanDer, mean.shape as ShapeMap[R]);\n };\n const derVariance = () => {\n let varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue);\n\n if (mean.rank === 1) {\n varianceDer = sum(varianceDer, reductionAxes);\n }\n return reshape(varianceDer, mean.shape as ShapeMap[R]);\n };\n const derScale = () => {\n const xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance);\n\n let scaleDer = mul(dy, xMinusMean2TimesRsqrt);\n if (mean.rank === 1) {\n scaleDer = sum(scaleDer, reductionAxes);\n }\n return reshape(scaleDer, mean.shape as ShapeMap[R]);\n };\n const derOffset = () => {\n let offsetDer = dy;\n if (mean.rank === 1) {\n offsetDer = sum(offsetDer, reductionAxes);\n }\n return reshape(offsetDer, mean.shape as ShapeMap[R]);\n };\n return {\n x: derX,\n mean: derMean,\n variance: derVariance,\n scale: derScale,\n offset: derOffset\n };\n }\n};\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Identity} from '../kernel_names';\nimport {GradConfig} from '../kernel_registry';\nimport {Tensor} from '../tensor';\n\nexport const identityGradConfig: GradConfig = {\n kernelName: Identity,\n gradFunc: (dy: Tensor) => {\n return {x: () => dy.toFloat()};\n }\n};\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {OneHot} from '../kernel_names';\nimport {GradConfig} from '../kernel_registry';\nimport {zeros} from '../ops/tensor_ops';\nimport {Tensor} from '../tensor';\n\nexport const oneHotGradConfig: GradConfig = {\n kernelName: OneHot,\n inputsToSave: ['indices'],\n gradFunc: (dy: Tensor, saved: Tensor[]) => {\n const indices = saved[0];\n return {indices: () => zeros(indices.shape, 'float32')};\n }\n};\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {PadV2, PadV2Attrs} from '../kernel_names';\nimport {GradConfig, NamedAttrMap} from '../kernel_registry';\nimport {Tensor} from '../tensor';\n\nexport const padV2GradConfig: GradConfig = {\n kernelName: PadV2,\n inputsToSave: ['x'],\n gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {\n // Pad introduces values around the original tensor, so the gradient\n // slices the original shape out of the gradient.\n const x = saved[0];\n const {paddings} = attrs as unknown as PadV2Attrs;\n const begin = paddings.map(p => p[0]);\n return {x: () => dy.slice(begin, x.shape)};\n }\n};\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Square} from '../kernel_names';\nimport {GradConfig} from '../kernel_registry';\nimport {Tensor} from '../tensor';\n\nexport const squareGradConfig: GradConfig = {\n kernelName: Square,\n inputsToSave: ['x'],\n gradFunc: (dy: Tensor, saved: Tensor[]) => {\n const [x] = saved;\n return {x: () => dy.mul(x.toFloat().mul(2))};\n }\n};\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {SquaredDifference} from '../kernel_names';\nimport {GradConfig} from '../kernel_registry';\nimport {mul, sub} from '../ops/binary_ops';\nimport {scalar} from '../ops/tensor_ops';\nimport {Tensor} from '../tensor';\n\nexport const squaredDifferenceGradConfig: GradConfig = {\n kernelName: SquaredDifference,\n inputsToSave: ['a', 'b'],\n gradFunc: (dy: Tensor, saved: Tensor[]) => {\n const [a, b] = saved;\n const two = scalar(2);\n const derA = () => mul(dy, mul(two, sub(a, b)));\n const derB = () => mul(dy, mul(two, sub(b, a)));\n return {a: derA, b: derB};\n }\n};\n","/**\n * @license\n * Copyright 2020 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tile, TileAttrs} from '../kernel_names';\nimport {GradConfig, NamedAttrMap} from '../kernel_registry';\nimport {zerosLike} from '../ops/tensor_ops';\nimport {Tensor} from '../tensor';\n\nexport const tileGradConfig: GradConfig = {\n kernelName: Tile,\n inputsToSave: ['x'],\n gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {\n const [x] = saved;\n const {reps} = attrs as unknown as TileAttrs;\n\n const derX = () => {\n let xGrad = zerosLike(x);\n // TODO(cais): Maybe reduce memory footprint by avoiding repeated\n // slicing.\n if (x.rank === 1) {\n for (let i = 0; i < reps[0]; ++i) {\n xGrad = xGrad.add(dy.slice([i * x.shape[0]], [x.shape[0]]));\n }\n } else if (x.rank === 2) {\n for (let i = 0; i < reps[0]; ++i) {\n for (let j = 0; j < reps[1]; ++j) {\n xGrad = xGrad.add(dy.slice(\n [i * x.shape[0], j * x.shape[1]], [x.shape[0], x.shape[1]]));\n }\n }\n } else if (x.rank === 3) {\n for (let i = 0; i < reps[0]; ++i) {\n for (let j = 0; j < reps[1]; ++j) {\n for (let k = 0; k < reps[2]; ++k) {\n xGrad = xGrad.add(dy.slice(\n [i * x.shape[0], j * x.shape[1], k * x.shape[2]],\n [x.shape[0], x.shape[1], x.shape[2]]));\n }\n }\n }\n } else if (x.rank === 4) {\n for (let i = 0; i < reps[0]; ++i) {\n for (let j = 0; j < reps[1]; ++j) {\n for (let k = 0; k < reps[2]; ++k) {\n for (let l = 0; l < reps[3]; ++l) {\n xGrad = xGrad.add(dy.slice(\n [\n i * x.shape[0], j * x.shape[1], k * x.shape[2],\n l * x.shape[3]\n ],\n [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]));\n }\n }\n }\n }\n } else {\n throw new Error(\n `Gradient for tile operation is not implemented for rank-` +\n `${x.rank} tensors yet.`);\n }\n return xGrad;\n };\n return {x: derX};\n },\n};\n","/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Transpose, TransposeAttrs} from '../kernel_names';\nimport {GradConfig, NamedAttrMap} from '../kernel_registry';\nimport * as axis_util from '../ops/axis_util';\nimport {transpose} from '../ops/transpose';\nimport {Tensor} from '../tensor';\n\nexport const transposeGradConfig: GradConfig = {\n kernelName: Transpose,\n gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {\n const transposeAttrs: TransposeAttrs = attrs as {} as TransposeAttrs;\n const {perm} = transposeAttrs;\n const undoPerm = axis_util.getUndoAxesPermutation(perm);\n return {x: () => transpose(dy, undoPerm)};\n }\n};\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../environment';\n\nimport {Platform} from './platform';\n\nexport class PlatformBrowser implements Platform {\n // According to the spec, the built-in encoder can do only UTF-8 encoding.\n // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder\n private textEncoder: TextEncoder;\n\n fetch(path: string, init?: RequestInit): Promise {\n return fetch(path, init);\n }\n\n now(): number {\n return performance.now();\n }\n\n encode(text: string, encoding: string): Uint8Array {\n if (encoding !== 'utf-8' && encoding !== 'utf8') {\n throw new Error(\n `Browser's encoder only supports utf-8, but got ${encoding}`);\n }\n if (this.textEncoder == null) {\n this.textEncoder = new TextEncoder();\n }\n return this.textEncoder.encode(text);\n }\n decode(bytes: Uint8Array, encoding: string): string {\n return new TextDecoder(encoding).decode(bytes);\n }\n}\n\nif (env().get('IS_BROWSER')) {\n env().setPlatform('browser', new PlatformBrowser());\n}\n","/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {env} from '../environment';\n\nimport {Platform} from './platform';\n\n// We are wrapping this within an object so it can be stubbed by Jasmine.\nexport const getNodeFetch = {\n // tslint:disable-next-line:no-require-imports\n importFetch: () => require('node-fetch')\n};\n\ntype FetchFn = (url: string, init?: RequestInit) => Promise;\nlet systemFetch: FetchFn;\n// These getters and setters are for testing so we don't export a mutable\n// variable.\nexport function resetSystemFetch() {\n systemFetch = null;\n}\nexport function setSystemFetch(fetchFn: FetchFn) {\n systemFetch = fetchFn;\n}\nexport function getSystemFetch(): FetchFn {\n return systemFetch;\n}\n\nexport class PlatformNode implements Platform {\n private textEncoder: TextEncoder;\n // tslint:disable-next-line:no-any\n util: any;\n\n constructor() {\n // tslint:disable-next-line:no-require-imports\n this.util = require('util');\n // According to the spec, the built-in encoder can do only UTF-8 encoding.\n // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder\n this.textEncoder = new this.util.TextEncoder();\n }\n\n fetch(path: string, requestInits?: RequestInit): Promise {\n if (env().global.fetch != null) {\n return env().global.fetch(path, requestInits);\n }\n\n if (systemFetch == null) {\n systemFetch = getNodeFetch.importFetch();\n }\n return systemFetch(path, requestInits);\n }\n\n now(): number {\n const time = process.hrtime();\n return time[0] * 1000 + time[1] / 1000000;\n }\n\n encode(text: string, encoding: string): Uint8Array {\n if (encoding !== 'utf-8' && encoding !== 'utf8') {\n throw new Error(\n `Node built-in encoder only supports utf-8, but got ${encoding}`);\n }\n return this.textEncoder.encode(text);\n }\n decode(bytes: Uint8Array, encoding: string): string {\n if (bytes.length === 0) {\n return '';\n }\n return new this.util.TextDecoder(encoding).decode(bytes);\n }\n}\n\nif (env().get('IS_NODE')) {\n env().setPlatform('node', new PlatformNode());\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/* Type definitions for exporting and importing of models. */\n\n/**\n * A map from Tensor dtype to number of bytes per element of the Tensor.\n */\nexport const DTYPE_VALUE_SIZE_MAP: {[dtype: string]: number} = {\n 'float32': 4,\n 'int32': 4,\n 'uint16': 2,\n 'uint8': 1,\n 'bool': 1,\n};\n\n/**\n * A weight manifest.\n *\n * The weight manifest consists of an ordered list of weight-manifest groups.\n * Each weight-manifest group (\"group\" for short hereafter) consists of a\n * number of weight values stored in a number of paths.\n * See the documentation of `WeightManifestGroupConfig` below for more details.\n */\nexport declare type WeightsManifestConfig = WeightsManifestGroupConfig[];\n\n/**\n * A weight-manifest group.\n *\n * Consists of an ordered list of weight values encoded in binary format,\n * stored in an ordered list of paths.\n */\nexport declare interface WeightsManifestGroupConfig {\n /**\n * An ordered list of paths.\n *\n * Paths are intentionally abstract in order to be general. For example, they\n * can be relative URL paths or relative paths on the file system.\n */\n paths: string[];\n\n /**\n * Specifications of the weights stored in the paths.\n */\n weights: WeightsManifestEntry[];\n}\n\n/**\n * Group to which the weight belongs.\n *\n * - 'optimizer': Weight from a stateful optimizer.\n */\nexport type WeightGroup = 'model'|'optimizer';\n\n/**\n * An entry in the weight manifest.\n *\n * The entry contains specification of a weight.\n */\nexport declare interface WeightsManifestEntry {\n /**\n * Name of the weight, e.g., 'Dense_1/bias'\n */\n name: string;\n\n /**\n * Shape of the weight.\n */\n shape: number[];\n\n /**\n * Data type of the weight.\n */\n dtype: 'float32'|'int32'|'bool'|'string';\n\n /**\n * Type of the weight.\n *\n * Optional.\n *\n * The value 'optimizer' indicates the weight belongs to an optimizer\n * (i.e., used only during model training and not during inference).\n */\n group?: WeightGroup;\n\n /**\n * Information for dequantization of the weight.\n */\n quantization?: {\n scale: number, // The scaling constant to multiply by.\n min: number, // The (possibly nudged) minimum weight to add.\n dtype: 'uint16'|'uint8' // The dtype of the quantized weights.\n };\n}\n\n/**\n * Options for saving a model.\n * @innamespace io\n */\nexport interface SaveConfig {\n /**\n * Whether to save only the trainable weights of the model, ignoring the\n * non-trainable ones.\n */\n trainableOnly?: boolean;\n\n /**\n * Whether the optimizer will be saved (if exists).\n *\n * Default: `false`.\n */\n includeOptimizer?: boolean;\n}\n\n/**\n * Result of a saving operation.\n */\nexport interface SaveResult {\n /**\n * Information about the model artifacts saved.\n */\n modelArtifactsInfo: ModelArtifactsInfo;\n\n /**\n * HTTP responses from the server that handled the model-saving request (if\n * any). This is applicable only to server-based saving routes.\n */\n responses?: Response[];\n\n /**\n * Error messages and related data (if any).\n */\n errors?: Array<{}|string>;\n}\n\nexport declare interface ModelArtifactsInfo {\n /**\n * Timestamp for when the model is saved.\n */\n dateSaved: Date;\n\n /**\n * TODO (cais,yassogba) consider removing GraphDef as GraphDefs now\n * come in a JSON format and none of our IOHandlers support a non json\n * format. We could conder replacing this with 'Binary' if we want to\n * allow future handlers to save to non json formats (though they will\n * probably want more information than 'Binary').\n * Type of the model topology\n *\n * Type of the model topology\n *\n * Possible values:\n * - JSON: JSON config (human-readable, e.g., Keras JSON).\n * - GraphDef: TensorFlow\n * [GraphDef](https://www.tensorflow.org/extend/tool_developers/#graphdef)\n * protocol buffer (binary).\n */\n modelTopologyType: 'JSON'|'GraphDef';\n\n /**\n * Size of model topology (Keras JSON or GraphDef), in bytes.\n */\n modelTopologyBytes?: number;\n\n /**\n * Size of weight specification or manifest, in bytes.\n */\n weightSpecsBytes?: number;\n\n /**\n * Size of weight value data, in bytes.\n */\n weightDataBytes?: number;\n}\n\n/** Model training configuration. */\nexport declare interface TrainingConfig {\n // TODO(cais): Tighten the typing once keras spec is available to tfjs-core.\n // See\n // tslint:disable-next-line:max-line-length\n // https://github.com/tensorflow/tfjs-layers/blob/master/src/keras_format/training_config.ts\n /** Optimizer used for the model training. */\n optimizer_config: {};\n\n // TODO(cais): Tighten the typing once keras spec is available to tfjs-core.\n /** Loss function(s) for the model's output(s). */\n loss: string|string[]|{[key: string]: string};\n\n // TODO(cais): Tighten the typing once keras spec is available to tfjs-core.\n /** Metric function(s) for the model's output(s). */\n metrics?: string[]|{[key: string]: string};\n\n // TODO(cais): Tighten the typing once keras spec is available to tfjs-core.\n weighted_metrics?: string[];\n\n // TODO(cais): Tighten the typing once keras spec is available to tfjs-core.\n sample_weight_mode?: string;\n\n loss_weights?: number[]|{[key: string]: number};\n}\n\n/**\n * The serialized artifacts of a model, including topology and weights.\n *\n * The `modelTopology`, `trainingConfig`, `weightSpecs` and `weightData` fields\n * of this interface are optional, in order to support topology- or weights-only\n * saving and loading.\n *\n * Note this interface is used internally in IOHandlers. For the file format\n * written to disk as `model.json`, see `ModelJSON`.\n */\nexport declare interface ModelArtifacts {\n /**\n * Model topology.\n *\n * For Keras-style `tf.Model`s, this is a JSON object.\n * For TensorFlow-style models (e.g., `SavedModel`), this is the JSON\n * encoding of the `GraphDef` protocol buffer.\n */\n modelTopology?: {}|ArrayBuffer;\n\n /**\n * Serialized configuration for the model's training.\n */\n trainingConfig?: TrainingConfig;\n\n /**\n * Weight specifications.\n *\n * This corresponds to the weightsData below.\n */\n weightSpecs?: WeightsManifestEntry[];\n\n /**\n * Binary buffer for all weight values concatenated in the order specified\n * by `weightSpecs`.\n */\n weightData?: ArrayBuffer;\n\n /**\n * Hard-coded format name for models saved from TensorFlow.js or converted\n * by TensorFlow.js Converter.\n */\n format?: string;\n\n /**\n * What library is responsible for originally generating this artifact.\n *\n * Used for debugging purposes. E.g., 'TensorFlow.js v1.0.0'.\n */\n generatedBy?: string;\n\n /**\n * What library or tool is responsible for converting the original model\n * to this format, applicable only if the model is output by a converter.\n *\n * Used for debugging purposes. E.g., 'TensorFlow.js Converter v1.0.0'.\n *\n * A value of `null` means the model artifacts are generated without any\n * conversion process (e.g., saved directly from a TensorFlow.js\n * `tf.LayersModel` instance.)\n */\n convertedBy?: string|null;\n\n /**\n * User-defined metadata about the model.\n */\n userDefinedMetadata?: {};\n}\n\n/**\n * The on-disk format of the `model.json` file.\n *\n * TF.js 1.0 always populates the optional fields when writing model.json.\n * Prior versions did not provide those fields.\n */\nexport declare interface ModelJSON {\n /**\n * Model topology.\n *\n * For Keras-style `tf.Model`s, this is a JSON object.\n * For TensorFlow-style models (e.g., `SavedModel`), this is the JSON\n * encoding of the `GraphDef` protocol buffer.\n */\n modelTopology: {};\n\n /** Model training configuration. */\n trainingConfig?: TrainingConfig;\n\n /**\n * Weights manifest.\n *\n * The weights manifest consists of an ordered list of weight-manifest\n * groups. Each weight-manifest group consists of a number of weight values\n * stored in a number of paths. See the documentation of\n * `WeightsManifestConfig` for more details.\n */\n weightsManifest: WeightsManifestConfig;\n\n /**\n * Hard-coded format name for models saved from TensorFlow.js or converted\n * by TensorFlow.js Converter.\n */\n format?: string;\n\n /**\n * What library is responsible for originally generating this artifact.\n *\n * Used for debugging purposes. E.g., 'TensorFlow.js v1.0.0'.\n */\n generatedBy?: string;\n\n /**\n * What library or tool is responsible for converting the original model\n * to this format, applicable only if the model is output by a converter.\n *\n * Used for debugging purposes. E.g., 'TensorFlow.js Converter v1.0.0'.\n *\n * A value of `null` means the model artifacts are generated without any\n * conversion process (e.g., saved directly from a TensorFlow.js\n * `tf.LayersModel` instance.)\n */\n convertedBy?: string|null;\n\n /**\n * User-defined metadata about the model.\n */\n userDefinedMetadata?: {};\n}\n\n/**\n * Type definition for handlers of loading operations.\n */\nexport type LoadHandler = () => Promise;\n\n/**\n * Type definition for handlers of saving operations.\n */\nexport type SaveHandler = (modelArtifact: ModelArtifacts) =>\n Promise;\n\n/**\n * Interface for a model import/export handler.\n *\n * The `save` and `load` handlers are both optional, in order to allow handlers\n * that support only saving or loading.\n */\n// tslint:disable-next-line:interface-name\nexport interface IOHandler {\n save?: SaveHandler;\n load?: LoadHandler;\n}\n\n/**\n * An interface for the manager of a model store.\n *\n * A model store is defined as a storage medium on which multiple models can\n * be stored. Each stored model has a unique `path` as its identifier.\n * A `ModelStoreManager` for the store allows actions including\n *\n * - Listing the models stored in the store.\n * - Deleting a model from the store.\n */\nexport interface ModelStoreManager {\n /**\n * List all models in the model store.\n *\n * @returns A dictionary mapping paths of existing models to their\n * model artifacts info. Model artifacts info include type of the model's\n * topology, byte sizes of the topology, weights, etc.\n */\n listModels(): Promise<{[path: string]: ModelArtifactsInfo}>;\n\n /**\n * Remove a model specified by `path`.\n *\n * @param path\n * @returns ModelArtifactsInfo of the deleted model (if and only if deletion\n * is successful).\n * @throws Error if deletion fails, e.g., if no model exists at `path`.\n */\n removeModel(path: string): Promise;\n}\n\n/**\n * Callback for the progress of a long-running action such as an HTTP\n * request for a large binary object.\n *\n * `fraction` should be a number in the [0, 1] interval, indicating how\n * much of the action has completed.\n */\nexport type OnProgressCallback = (fraction: number) => void;\n\n/** @innamespace io */\nexport interface LoadOptions {\n /**\n * RequestInit (options) for HTTP requests.\n *\n * For detailed information on the supported fields, see\n * [https://developer.mozilla.org/en-US/docs/Web/API/Request/Request](\n * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request)\n */\n requestInit?: RequestInit;\n\n /**\n * Progress callback.\n */\n onProgress?: OnProgressCallback;\n\n /**\n * A function used to override the `window.fetch` function.\n */\n fetchFunc?: Function;\n\n /**\n * Strict loading model: whether extraneous weights or missing\n * weights should trigger an `Error`.\n *\n * If `true`, require that the provided weights exactly match those\n * required by the layers. `false` means that both extra weights\n * and missing weights will be silently ignored.\n *\n * Default: `true`.\n */\n strict?: boolean;\n\n /**\n * Path prefix for weight files, by default this is calculated from the\n * path of the model JSON file.\n *\n * For instance, if the path to the model JSON file is\n * `http://localhost/foo/model.json`, then the default path prefix will be\n * `http://localhost/foo/`. If a weight file has the path value\n * `group1-shard1of2` in the weight manifest, then the weight file will be\n * loaded from `http://localhost/foo/group1-shard1of2` by default. However,\n * if you provide a `weightPathPrefix` value of\n * `http://localhost/foo/alt-weights`, then the weight file will be loaded\n * from the path `http://localhost/foo/alt-weights/group1-shard1of2` instead.\n */\n weightPathPrefix?: string;\n\n /**\n * Whether the module or model is to be loaded from TF Hub.\n *\n * Setting this to `true` allows passing a TF-Hub module URL, omitting the\n * standard model file name and the query parameters.\n *\n * Default: `false`.\n */\n fromTFHub?: boolean;\n}\n\n/**\n * Additional options for Platform.fetch\n */\nexport interface RequestDetails {\n /**\n * Is this request for a binary file (as opposed to a json file)\n */\n isBinary?: boolean;\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {tensor} from '../ops/tensor_ops';\nimport {NamedTensor, NamedTensorMap} from '../tensor_types';\nimport {TypedArray} from '../types';\nimport {sizeFromShape} from '../util';\n\nimport {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, WeightGroup, WeightsManifestEntry} from './types';\n\n/** Number of bytes reserved for the length of the string. (32bit integer). */\nconst NUM_BYTES_STRING_LENGTH = 4;\n\n/**\n * Encode a map from names to weight values as an ArrayBuffer, along with an\n * `Array` of `WeightsManifestEntry` as specification of the encoded weights.\n *\n * This function does not perform sharding.\n *\n * This function is the reverse of `decodeWeights`.\n *\n * @param tensors A map (\"dict\") from names to tensors.\n * @param group Group to which the weights belong (optional).\n * @returns A `Promise` of\n * - A flat `ArrayBuffer` with all the binary values of the `Tensor`s\n * concatenated.\n * - An `Array` of `WeightManifestEntry`s, carrying information including\n * tensor names, `dtype`s and shapes.\n * @throws Error: on unsupported tensor `dtype`.\n */\nexport async function encodeWeights(\n tensors: NamedTensorMap|NamedTensor[], group?: WeightGroup):\n Promise<{data: ArrayBuffer, specs: WeightsManifestEntry[]}> {\n // TODO(adarob, cais): Support quantization.\n const specs: WeightsManifestEntry[] = [];\n const dataPromises: Array> = [];\n\n const names: string[] = Array.isArray(tensors) ?\n tensors.map(tensor => tensor.name) :\n Object.keys(tensors);\n\n for (let i = 0; i < names.length; ++i) {\n const name = names[i];\n const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name];\n if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' &&\n t.dtype !== 'string') {\n throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`);\n }\n const spec: WeightsManifestEntry = {name, shape: t.shape, dtype: t.dtype};\n if (t.dtype === 'string') {\n const utf8bytes = new Promise(async resolve => {\n const vals = await t.bytes() as Uint8Array[];\n const totalNumBytes = vals.reduce((p, c) => p + c.length, 0) +\n NUM_BYTES_STRING_LENGTH * vals.length;\n const bytes = new Uint8Array(totalNumBytes);\n let offset = 0;\n for (let i = 0; i < vals.length; i++) {\n const val = vals[i];\n const bytesOfLength =\n new Uint8Array(new Uint32Array([val.length]).buffer);\n bytes.set(bytesOfLength, offset);\n offset += NUM_BYTES_STRING_LENGTH;\n bytes.set(val, offset);\n offset += val.length;\n }\n resolve(bytes);\n });\n dataPromises.push(utf8bytes);\n } else {\n dataPromises.push(t.data());\n }\n if (group != null) {\n spec.group = group;\n }\n specs.push(spec);\n }\n\n const tensorValues = await Promise.all(dataPromises);\n return {data: concatenateTypedArrays(tensorValues), specs};\n}\n\n/**\n * Decode flat ArrayBuffer as weights.\n *\n * This function does not handle sharding.\n *\n * This function is the reverse of `encodeWeights`.\n *\n * @param buffer A flat ArrayBuffer carrying the binary values of the tensors\n * concatenated in the order specified in `specs`.\n * @param specs Specifications of the names, dtypes and shapes of the tensors\n * whose value are encoded by `buffer`.\n * @return A map from tensor name to tensor value, with the names corresponding\n * to names in `specs`.\n * @throws Error, if any of the tensors has unsupported dtype.\n */\nexport function decodeWeights(\n buffer: ArrayBuffer, specs: WeightsManifestEntry[]): NamedTensorMap {\n // TODO(adarob, cais): Support quantization.\n const out: NamedTensorMap = {};\n let offset = 0;\n for (const spec of specs) {\n const name = spec.name;\n const dtype = spec.dtype;\n const shape = spec.shape;\n const size = sizeFromShape(shape);\n let values: TypedArray|string[]|Uint8Array[];\n\n if ('quantization' in spec) {\n const quantization = spec.quantization;\n if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {\n throw new Error(\n `Weight ${spec.name} has unknown ` +\n `quantization dtype ${quantization.dtype}. ` +\n `Supported quantization dtypes are: 'uint8' and 'uint16'.`);\n }\n const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];\n const byteBuffer =\n buffer.slice(offset, offset + size * quantizationSizeFactor);\n const quantizedArray = (quantization.dtype === 'uint8') ?\n new Uint8Array(byteBuffer) :\n new Uint16Array(byteBuffer);\n if (dtype === 'float32') {\n values = Float32Array.from(\n quantizedArray, v => v * quantization.scale + quantization.min);\n } else if (dtype === 'int32') {\n values = Int32Array.from(\n quantizedArray,\n v => Math.round(v * quantization.scale + quantization.min));\n } else {\n throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);\n }\n offset += size * quantizationSizeFactor;\n } else if (dtype === 'string') {\n const size = sizeFromShape(spec.shape);\n values = [];\n for (let i = 0; i < size; i++) {\n const byteLength = new Uint32Array(\n buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];\n offset += NUM_BYTES_STRING_LENGTH;\n const bytes = new Uint8Array(buffer.slice(offset, offset + byteLength));\n (values as Uint8Array[]).push(bytes);\n offset += byteLength;\n }\n } else {\n const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];\n const byteBuffer = buffer.slice(offset, offset + size * dtypeFactor);\n\n if (dtype === 'float32') {\n values = new Float32Array(byteBuffer);\n } else if (dtype === 'int32') {\n values = new Int32Array(byteBuffer);\n } else if (dtype === 'bool') {\n values = new Uint8Array(byteBuffer);\n } else {\n throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);\n }\n offset += size * dtypeFactor;\n }\n\n out[name] = tensor(values, shape, dtype);\n }\n return out;\n}\n\n/**\n * Concatenate TypedArrays into an ArrayBuffer.\n */\nexport function concatenateTypedArrays(xs: TypedArray[]): ArrayBuffer {\n // TODO(adarob, cais): Support quantization.\n if (xs === null) {\n throw new Error(`Invalid input value: ${JSON.stringify(xs)}`);\n }\n\n let totalByteLength = 0;\n\n // `normalizedXs` is here for this reason: a `TypedArray`'s `buffer'\n // can have a different byte length from that of the `TypedArray` itself,\n // for example, when the `TypedArray` is created from an offset in an\n // `ArrayBuffer`. `normliazedXs` holds `TypedArray`s whose `buffer`s match\n // the `TypedArray` in byte length. If an element of `xs` does not show\n // this property, a new `TypedArray` that satisfy this property will be\n // constructed and pushed into `normalizedXs`.\n const normalizedXs: TypedArray[] = [];\n xs.forEach((x: TypedArray) => {\n totalByteLength += x.byteLength;\n // tslint:disable:no-any\n normalizedXs.push(\n x.byteLength === x.buffer.byteLength ? x :\n new (x.constructor as any)(x));\n if (!(x as any instanceof Float32Array || x as any instanceof Int32Array ||\n x as any instanceof Uint8Array)) {\n throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`);\n }\n // tslint:enable:no-any\n });\n\n const y = new Uint8Array(totalByteLength);\n let offset = 0;\n normalizedXs.forEach((x: TypedArray) => {\n y.set(new Uint8Array(x.buffer), offset);\n offset += x.byteLength;\n });\n\n return y.buffer;\n}\n\n// Use Buffer on Node.js instead of Blob/atob/btoa\nconst useNodeBuffer = typeof Buffer !== 'undefined' &&\n (typeof Blob === 'undefined' || typeof atob === 'undefined' ||\n typeof btoa === 'undefined');\n\n/**\n * Calculate the byte length of a JavaScript string.\n *\n * Note that a JavaScript string can contain wide characters, therefore the\n * length of the string is not necessarily equal to the byte length.\n *\n * @param str Input string.\n * @returns Byte length.\n */\nexport function stringByteLength(str: string): number {\n if (useNodeBuffer) {\n return Buffer.byteLength(str);\n }\n return new Blob([str]).size;\n}\n\n/**\n * Encode an ArrayBuffer as a base64 encoded string.\n *\n * @param buffer `ArrayBuffer` to be converted.\n * @returns A string that base64-encodes `buffer`.\n */\nexport function arrayBufferToBase64String(buffer: ArrayBuffer): string {\n if (useNodeBuffer) {\n return Buffer.from(buffer).toString('base64');\n }\n const buf = new Uint8Array(buffer);\n let s = '';\n for (let i = 0, l = buf.length; i < l; i++) {\n s += String.fromCharCode(buf[i]);\n }\n return btoa(s);\n}\n\n/**\n * Decode a base64 string as an ArrayBuffer.\n *\n * @param str Base64 string.\n * @returns Decoded `ArrayBuffer`.\n */\nexport function base64StringToArrayBuffer(str: string): ArrayBuffer {\n if (useNodeBuffer) {\n const buf = Buffer.from(str, 'base64');\n return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);\n }\n const s = atob(str);\n const buffer = new Uint8Array(s.length);\n for (let i = 0; i < s.length; ++i) {\n buffer.set([s.charCodeAt(i)], i);\n }\n return buffer.buffer;\n}\n\n/**\n * Concatenate a number of ArrayBuffers into one.\n *\n * @param buffers A number of array buffers to concatenate.\n * @returns Result of concatenating `buffers` in order.\n */\nexport function concatenateArrayBuffers(buffers: ArrayBuffer[]): ArrayBuffer {\n let totalByteLength = 0;\n buffers.forEach((buffer: ArrayBuffer) => {\n totalByteLength += buffer.byteLength;\n });\n\n const temp = new Uint8Array(totalByteLength);\n let offset = 0;\n buffers.forEach((buffer: ArrayBuffer) => {\n temp.set(new Uint8Array(buffer), offset);\n offset += buffer.byteLength;\n });\n return temp.buffer;\n}\n\n/**\n * Get the basename of a path.\n *\n * Behaves in a way analogous to Linux's basename command.\n *\n * @param path\n */\nexport function basename(path: string): string {\n const SEPARATOR = '/';\n path = path.trim();\n while (path.endsWith(SEPARATOR)) {\n path = path.slice(0, path.length - 1);\n }\n const items = path.split(SEPARATOR);\n return items[items.length - 1];\n}\n\n/**\n * Populate ModelArtifactsInfo fields for a model with JSON topology.\n * @param modelArtifacts\n * @returns A ModelArtifactsInfo object.\n */\nexport function getModelArtifactsInfoForJSON(modelArtifacts: ModelArtifacts):\n ModelArtifactsInfo {\n if (modelArtifacts.modelTopology instanceof ArrayBuffer) {\n throw new Error('Expected JSON model topology, received ArrayBuffer.');\n }\n\n return {\n dateSaved: new Date(),\n modelTopologyType: 'JSON',\n modelTopologyBytes: modelArtifacts.modelTopology == null ?\n 0 :\n stringByteLength(JSON.stringify(modelArtifacts.modelTopology)),\n weightSpecsBytes: modelArtifacts.weightSpecs == null ?\n 0 :\n stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)),\n weightDataBytes: modelArtifacts.weightData == null ?\n 0 :\n modelArtifacts.weightData.byteLength,\n };\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {IOHandler} from './types';\n\nexport type IORouter = (url: string|string[], onProgress?: Function) =>\n IOHandler;\n\nexport class IORouterRegistry {\n // Singleton instance.\n private static instance: IORouterRegistry;\n\n private saveRouters: IORouter[];\n private loadRouters: IORouter[];\n\n private constructor() {\n this.saveRouters = [];\n this.loadRouters = [];\n }\n\n private static getInstance(): IORouterRegistry {\n if (IORouterRegistry.instance == null) {\n IORouterRegistry.instance = new IORouterRegistry();\n }\n return IORouterRegistry.instance;\n }\n\n /**\n * Register a save-handler router.\n *\n * @param saveRouter A function that maps a URL-like string onto an instance\n * of `IOHandler` with the `save` method defined or `null`.\n */\n static registerSaveRouter(saveRouter: IORouter) {\n IORouterRegistry.getInstance().saveRouters.push(saveRouter);\n }\n\n /**\n * Register a load-handler router.\n *\n * @param loadRouter A function that maps a URL-like string onto an instance\n * of `IOHandler` with the `load` method defined or `null`.\n */\n static registerLoadRouter(loadRouter: IORouter) {\n IORouterRegistry.getInstance().loadRouters.push(loadRouter);\n }\n\n /**\n * Look up IOHandler for saving, given a URL-like string.\n *\n * @param url\n * @returns If only one match is found, an instance of IOHandler with the\n * `save` method defined. If no match is found, `null`.\n * @throws Error, if more than one match is found.\n */\n static getSaveHandlers(url: string|string[]): IOHandler[] {\n return IORouterRegistry.getHandlers(url, 'save');\n }\n\n /**\n * Look up IOHandler for loading, given a URL-like string.\n *\n * @param url\n * @param onProgress Optional, progress callback function, fired periodically\n * before the load is completed.\n * @returns All valid handlers for `url`, given the currently registered\n * handler routers.\n */\n static getLoadHandlers(url: string|string[], onProgress?: Function):\n IOHandler[] {\n return IORouterRegistry.getHandlers(url, 'load', onProgress);\n }\n\n private static getHandlers(\n url: string|string[], handlerType: 'save'|'load',\n onProgress?: Function): IOHandler[] {\n const validHandlers: IOHandler[] = [];\n const routers = handlerType === 'load' ?\n IORouterRegistry.getInstance().loadRouters :\n IORouterRegistry.getInstance().saveRouters;\n routers.forEach(router => {\n const handler = router(url, onProgress);\n if (handler !== null) {\n validHandlers.push(handler);\n }\n });\n return validHandlers;\n }\n}\n\nexport const registerSaveRouter = (loudRouter: IORouter) =>\n IORouterRegistry.registerSaveRouter(loudRouter);\nexport const registerLoadRouter = (loudRouter: IORouter) =>\n IORouterRegistry.registerLoadRouter(loudRouter);\nexport const getSaveHandlers = (url: string|string[]) =>\n IORouterRegistry.getSaveHandlers(url);\nexport const getLoadHandlers = (url: string|string[], onProgress?: Function) =>\n IORouterRegistry.getLoadHandlers(url, onProgress);\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * Classes and functions for model management across multiple storage mediums.\n *\n * Supported client actions:\n * - Listing models on all registered storage mediums.\n * - Remove model by URL from any registered storage mediums, by using URL\n * string.\n * - Moving or copying model from one path to another in the same medium or from\n * one medium to another, by using URL strings.\n */\n\nimport {assert} from '../util';\n\nimport {IORouterRegistry} from './router_registry';\nimport {ModelArtifactsInfo, ModelStoreManager} from './types';\n\nconst URL_SCHEME_SUFFIX = '://';\n\nexport class ModelStoreManagerRegistry {\n // Singleton instance.\n private static instance: ModelStoreManagerRegistry;\n\n private managers: {[scheme: string]: ModelStoreManager};\n\n private constructor() {\n this.managers = {};\n }\n\n private static getInstance(): ModelStoreManagerRegistry {\n if (ModelStoreManagerRegistry.instance == null) {\n ModelStoreManagerRegistry.instance = new ModelStoreManagerRegistry();\n }\n return ModelStoreManagerRegistry.instance;\n }\n\n /**\n * Register a save-handler router.\n *\n * @param saveRouter A function that maps a URL-like string onto an instance\n * of `IOHandler` with the `save` method defined or `null`.\n */\n static registerManager(scheme: string, manager: ModelStoreManager) {\n assert(scheme != null, () => 'scheme must not be undefined or null.');\n if (scheme.endsWith(URL_SCHEME_SUFFIX)) {\n scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX));\n }\n assert(scheme.length > 0, () => 'scheme must not be an empty string.');\n const registry = ModelStoreManagerRegistry.getInstance();\n assert(\n registry.managers[scheme] == null,\n () => `A model store manager is already registered for scheme '${\n scheme}'.`);\n registry.managers[scheme] = manager;\n }\n\n static getManager(scheme: string): ModelStoreManager {\n const manager = this.getInstance().managers[scheme];\n if (manager == null) {\n throw new Error(`Cannot find model manager for scheme '${scheme}'`);\n }\n return manager;\n }\n\n static getSchemes(): string[] {\n return Object.keys(this.getInstance().managers);\n }\n}\n\n/**\n * Helper method for parsing a URL string into a scheme and a path.\n *\n * @param url E.g., 'localstorage://my-model'\n * @returns A dictionary with two fields: scheme and path.\n * Scheme: e.g., 'localstorage' in the example above.\n * Path: e.g., 'my-model' in the example above.\n */\nfunction parseURL(url: string): {scheme: string, path: string} {\n if (url.indexOf(URL_SCHEME_SUFFIX) === -1) {\n throw new Error(\n `The url string provided does not contain a scheme. ` +\n `Supported schemes are: ` +\n `${ModelStoreManagerRegistry.getSchemes().join(',')}`);\n }\n return {\n scheme: url.split(URL_SCHEME_SUFFIX)[0],\n path: url.split(URL_SCHEME_SUFFIX)[1],\n };\n}\n\nasync function cloneModelInternal(\n sourceURL: string, destURL: string,\n deleteSource = false): Promise {\n assert(\n sourceURL !== destURL,\n () => `Old path and new path are the same: '${sourceURL}'`);\n\n const loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL);\n assert(\n loadHandlers.length > 0,\n () => `Copying failed because no load handler is found for source URL ${\n sourceURL}.`);\n assert(\n loadHandlers.length < 2,\n () => `Copying failed because more than one (${loadHandlers.length}) ` +\n `load handlers for source URL ${sourceURL}.`);\n const loadHandler = loadHandlers[0];\n\n const saveHandlers = IORouterRegistry.getSaveHandlers(destURL);\n assert(\n saveHandlers.length > 0,\n () => `Copying failed because no save handler is found for destination ` +\n `URL ${destURL}.`);\n assert(\n saveHandlers.length < 2,\n () => `Copying failed because more than one (${loadHandlers.length}) ` +\n `save handlers for destination URL ${destURL}.`);\n const saveHandler = saveHandlers[0];\n\n const sourceScheme = parseURL(sourceURL).scheme;\n const sourcePath = parseURL(sourceURL).path;\n const sameMedium = sourceScheme === parseURL(sourceURL).scheme;\n\n const modelArtifacts = await loadHandler.load();\n\n // If moving within the same storage medium, remove the old model as soon as\n // the loading is done. Without doing this, it is possible that the combined\n // size of the two models will cause the cloning to fail.\n if (deleteSource && sameMedium) {\n await ModelStoreManagerRegistry.getManager(sourceScheme)\n .removeModel(sourcePath);\n }\n\n const saveResult = await saveHandler.save(modelArtifacts);\n\n // If moving between mediums, the deletion is done after the save succeeds.\n // This guards against the case in which saving to the destination medium\n // fails.\n if (deleteSource && !sameMedium) {\n await ModelStoreManagerRegistry.getManager(sourceScheme)\n .removeModel(sourcePath);\n }\n\n return saveResult.modelArtifactsInfo;\n}\n\n/**\n * List all models stored in registered storage mediums.\n *\n * For a web browser environment, the registered mediums are Local Storage and\n * IndexedDB.\n *\n * ```js\n * // First create and save a model.\n * const model = tf.sequential();\n * model.add(tf.layers.dense(\n * {units: 1, inputShape: [10], activation: 'sigmoid'}));\n * await model.save('localstorage://demo/management/model1');\n *\n * // Then list existing models.\n * console.log(JSON.stringify(await tf.io.listModels()));\n *\n * // Delete the model.\n * await tf.io.removeModel('localstorage://demo/management/model1');\n *\n * // List models again.\n * console.log(JSON.stringify(await tf.io.listModels()));\n * ```\n *\n * @returns A `Promise` of a dictionary mapping URLs of existing models to\n * their model artifacts info. URLs include medium-specific schemes, e.g.,\n * 'indexeddb://my/model/1'. Model artifacts info include type of the\n * model's topology, byte sizes of the topology, weights, etc.\n */\n/**\n * @doc {\n * heading: 'Models',\n * subheading: 'Management',\n * namespace: 'io',\n * ignoreCI: true\n * }\n */\nasync function listModels(): Promise<{[url: string]: ModelArtifactsInfo}> {\n const schemes = ModelStoreManagerRegistry.getSchemes();\n const out: {[url: string]: ModelArtifactsInfo} = {};\n for (const scheme of schemes) {\n const schemeOut =\n await ModelStoreManagerRegistry.getManager(scheme).listModels();\n for (const path in schemeOut) {\n const url = scheme + URL_SCHEME_SUFFIX + path;\n out[url] = schemeOut[path];\n }\n }\n return out;\n}\n\n/**\n * Remove a model specified by URL from a reigstered storage medium.\n *\n * ```js\n * // First create and save a model.\n * const model = tf.sequential();\n * model.add(tf.layers.dense(\n * {units: 1, inputShape: [10], activation: 'sigmoid'}));\n * await model.save('localstorage://demo/management/model1');\n *\n * // Then list existing models.\n * console.log(JSON.stringify(await tf.io.listModels()));\n *\n * // Delete the model.\n * await tf.io.removeModel('localstorage://demo/management/model1');\n *\n * // List models again.\n * console.log(JSON.stringify(await tf.io.listModels()));\n * ```\n *\n * @param url A URL to a stored model, with a scheme prefix, e.g.,\n * 'localstorage://my-model-1', 'indexeddb://my/model/2'.\n * @returns ModelArtifactsInfo of the deleted model (if and only if deletion\n * is successful).\n * @throws Error if deletion fails, e.g., if no model exists at `path`.\n */\n/**\n * @doc {\n * heading: 'Models',\n * subheading: 'Management',\n * namespace: 'io',\n * ignoreCI: true\n * }\n */\nasync function removeModel(url: string): Promise {\n const schemeAndPath = parseURL(url);\n const manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme);\n return manager.removeModel(schemeAndPath.path);\n}\n\n/**\n * Copy a model from one URL to another.\n *\n * This function supports:\n *\n * 1. Copying within a storage medium, e.g.,\n * `tf.io.copyModel('localstorage://model-1', 'localstorage://model-2')`\n * 2. Copying between two storage mediums, e.g.,\n * `tf.io.copyModel('localstorage://model-1', 'indexeddb://model-1')`\n *\n * ```js\n * // First create and save a model.\n * const model = tf.sequential();\n * model.add(tf.layers.dense(\n * {units: 1, inputShape: [10], activation: 'sigmoid'}));\n * await model.save('localstorage://demo/management/model1');\n *\n * // Then list existing models.\n * console.log(JSON.stringify(await tf.io.listModels()));\n *\n * // Copy the model, from Local Storage to IndexedDB.\n * await tf.io.copyModel(\n * 'localstorage://demo/management/model1',\n * 'indexeddb://demo/management/model1');\n *\n * // List models again.\n * console.log(JSON.stringify(await tf.io.listModels()));\n *\n * // Remove both models.\n * await tf.io.removeModel('localstorage://demo/management/model1');\n * await tf.io.removeModel('indexeddb://demo/management/model1');\n * ```\n *\n * @param sourceURL Source URL of copying.\n * @param destURL Destination URL of copying.\n * @returns ModelArtifactsInfo of the copied model (if and only if copying\n * is successful).\n * @throws Error if copying fails, e.g., if no model exists at `sourceURL`, or\n * if `oldPath` and `newPath` are identical.\n */\n/**\n * @doc {\n * heading: 'Models',\n * subheading: 'Management',\n * namespace: 'io',\n * ignoreCI: true\n * }\n */\nasync function copyModel(\n sourceURL: string, destURL: string): Promise {\n const deleteSource = false;\n return cloneModelInternal(sourceURL, destURL, deleteSource);\n}\n\n/**\n * Move a model from one URL to another.\n *\n * This function supports:\n *\n * 1. Moving within a storage medium, e.g.,\n * `tf.io.moveModel('localstorage://model-1', 'localstorage://model-2')`\n * 2. Moving between two storage mediums, e.g.,\n * `tf.io.moveModel('localstorage://model-1', 'indexeddb://model-1')`\n *\n * ```js\n * // First create and save a model.\n * const model = tf.sequential();\n * model.add(tf.layers.dense(\n * {units: 1, inputShape: [10], activation: 'sigmoid'}));\n * await model.save('localstorage://demo/management/model1');\n *\n * // Then list existing models.\n * console.log(JSON.stringify(await tf.io.listModels()));\n *\n * // Move the model, from Local Storage to IndexedDB.\n * await tf.io.moveModel(\n * 'localstorage://demo/management/model1',\n * 'indexeddb://demo/management/model1');\n *\n * // List models again.\n * console.log(JSON.stringify(await tf.io.listModels()));\n *\n * // Remove the moved model.\n * await tf.io.removeModel('indexeddb://demo/management/model1');\n * ```\n *\n * @param sourceURL Source URL of moving.\n * @param destURL Destination URL of moving.\n * @returns ModelArtifactsInfo of the copied model (if and only if copying\n * is successful).\n * @throws Error if moving fails, e.g., if no model exists at `sourceURL`, or\n * if `oldPath` and `newPath` are identical.\n */\n/**\n * @doc {\n * heading: 'Models',\n * subheading: 'Management',\n * namespace: 'io',\n * ignoreCI: true\n * }\n */\nasync function moveModel(\n sourceURL: string, destURL: string): Promise {\n const deleteSource = true;\n return cloneModelInternal(sourceURL, destURL, deleteSource);\n}\n\nexport {moveModel, copyModel, removeModel, listModels};\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../environment';\n\nimport {getModelArtifactsInfoForJSON} from './io_utils';\nimport {ModelStoreManagerRegistry} from './model_management';\nimport {IORouter, IORouterRegistry} from './router_registry';\nimport {IOHandler, ModelArtifacts, ModelArtifactsInfo, ModelStoreManager, SaveResult} from './types';\n\nconst DATABASE_NAME = 'tensorflowjs';\nconst DATABASE_VERSION = 1;\n\n// Model data and ModelArtifactsInfo (metadata) are stored in two separate\n// stores for efficient access of the list of stored models and their metadata.\n// 1. The object store for model data: topology, weights and weight manifests.\nconst MODEL_STORE_NAME = 'models_store';\n// 2. The object store for ModelArtifactsInfo, including meta-information such\n// as the type of topology (JSON vs binary), byte size of the topology, byte\n// size of the weights, etc.\nconst INFO_STORE_NAME = 'model_info_store';\n\n/**\n * Delete the entire database for tensorflow.js, including the models store.\n */\nexport async function deleteDatabase(): Promise {\n const idbFactory = getIndexedDBFactory();\n\n return new Promise((resolve, reject) => {\n const deleteRequest = idbFactory.deleteDatabase(DATABASE_NAME);\n deleteRequest.onsuccess = () => resolve();\n deleteRequest.onerror = error => reject(error);\n });\n}\n\nfunction getIndexedDBFactory(): IDBFactory {\n if (!env().getBool('IS_BROWSER')) {\n // TODO(cais): Add more info about what IOHandler subtypes are available.\n // Maybe point to a doc page on the web and/or automatically determine\n // the available IOHandlers and print them in the error message.\n throw new Error(\n 'Failed to obtain IndexedDB factory because the current environment' +\n 'is not a web browser.');\n }\n // tslint:disable-next-line:no-any\n const theWindow: any = window || self;\n const factory = theWindow.indexedDB || theWindow.mozIndexedDB ||\n theWindow.webkitIndexedDB || theWindow.msIndexedDB ||\n theWindow.shimIndexedDB;\n if (factory == null) {\n throw new Error(\n 'The current browser does not appear to support IndexedDB.');\n }\n return factory;\n}\n\nfunction setUpDatabase(openRequest: IDBRequest) {\n const db = openRequest.result as IDBDatabase;\n db.createObjectStore(MODEL_STORE_NAME, {keyPath: 'modelPath'});\n db.createObjectStore(INFO_STORE_NAME, {keyPath: 'modelPath'});\n}\n\n/**\n * IOHandler subclass: Browser IndexedDB.\n *\n * See the doc string of `browserIndexedDB` for more details.\n */\nexport class BrowserIndexedDB implements IOHandler {\n protected readonly indexedDB: IDBFactory;\n protected readonly modelPath: string;\n\n static readonly URL_SCHEME = 'indexeddb://';\n\n constructor(modelPath: string) {\n this.indexedDB = getIndexedDBFactory();\n\n if (modelPath == null || !modelPath) {\n throw new Error(\n 'For IndexedDB, modelPath must not be null, undefined or empty.');\n }\n this.modelPath = modelPath;\n }\n\n async save(modelArtifacts: ModelArtifacts): Promise {\n // TODO(cais): Support saving GraphDef models.\n if (modelArtifacts.modelTopology instanceof ArrayBuffer) {\n throw new Error(\n 'BrowserLocalStorage.save() does not support saving model topology ' +\n 'in binary formats yet.');\n }\n\n return this.databaseAction(this.modelPath, modelArtifacts) as\n Promise;\n }\n\n async load(): Promise {\n return this.databaseAction(this.modelPath) as Promise;\n }\n\n /**\n * Perform database action to put model artifacts into or read model artifacts\n * from IndexedDB object store.\n *\n * Whether the action is put or get depends on whether `modelArtifacts` is\n * specified. If it is specified, the action will be put; otherwise the action\n * will be get.\n *\n * @param modelPath A unique string path for the model.\n * @param modelArtifacts If specified, it will be the model artifacts to be\n * stored in IndexedDB.\n * @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise`\n * of `ModelArtifacts`, if the action is get.\n */\n private databaseAction(modelPath: string, modelArtifacts?: ModelArtifacts):\n Promise {\n return new Promise((resolve, reject) => {\n const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);\n openRequest.onupgradeneeded = () => setUpDatabase(openRequest);\n\n openRequest.onsuccess = () => {\n const db = openRequest.result;\n\n if (modelArtifacts == null) {\n // Read model out from object store.\n const modelTx = db.transaction(MODEL_STORE_NAME, 'readonly');\n const modelStore = modelTx.objectStore(MODEL_STORE_NAME);\n const getRequest = modelStore.get(this.modelPath);\n getRequest.onsuccess = () => {\n if (getRequest.result == null) {\n db.close();\n return reject(new Error(\n `Cannot find model with path '${this.modelPath}' ` +\n `in IndexedDB.`));\n } else {\n resolve(getRequest.result.modelArtifacts);\n }\n };\n getRequest.onerror = error => {\n db.close();\n return reject(getRequest.error);\n };\n modelTx.oncomplete = () => db.close();\n } else {\n // Put model into object store.\n const modelArtifactsInfo: ModelArtifactsInfo =\n getModelArtifactsInfoForJSON(modelArtifacts);\n // First, put ModelArtifactsInfo into info store.\n const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');\n let infoStore = infoTx.objectStore(INFO_STORE_NAME);\n const putInfoRequest =\n infoStore.put({modelPath: this.modelPath, modelArtifactsInfo});\n let modelTx: IDBTransaction;\n putInfoRequest.onsuccess = () => {\n // Second, put model data into model store.\n modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');\n const modelStore = modelTx.objectStore(MODEL_STORE_NAME);\n const putModelRequest = modelStore.put({\n modelPath: this.modelPath,\n modelArtifacts,\n modelArtifactsInfo\n });\n putModelRequest.onsuccess = () => resolve({modelArtifactsInfo});\n putModelRequest.onerror = error => {\n // If the put-model request fails, roll back the info entry as\n // well.\n infoStore = infoTx.objectStore(INFO_STORE_NAME);\n const deleteInfoRequest = infoStore.delete(this.modelPath);\n deleteInfoRequest.onsuccess = () => {\n db.close();\n return reject(putModelRequest.error);\n };\n deleteInfoRequest.onerror = error => {\n db.close();\n return reject(putModelRequest.error);\n };\n };\n };\n putInfoRequest.onerror = error => {\n db.close();\n return reject(putInfoRequest.error);\n };\n infoTx.oncomplete = () => {\n if (modelTx == null) {\n db.close();\n } else {\n modelTx.oncomplete = () => db.close();\n }\n };\n }\n };\n openRequest.onerror = error => reject(openRequest.error);\n });\n }\n}\n\nexport const indexedDBRouter: IORouter = (url: string|string[]) => {\n if (!env().getBool('IS_BROWSER')) {\n return null;\n } else {\n if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) {\n return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length));\n } else {\n return null;\n }\n }\n};\nIORouterRegistry.registerSaveRouter(indexedDBRouter);\nIORouterRegistry.registerLoadRouter(indexedDBRouter);\n\n/**\n * Creates a browser IndexedDB IOHandler for saving and loading models.\n *\n * ```js\n * const model = tf.sequential();\n * model.add(\n * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));\n *\n * const saveResult = await model.save('indexeddb://MyModel'));\n * console.log(saveResult);\n * ```\n *\n * @param modelPath A unique identifier for the model to be saved. Must be a\n * non-empty string.\n * @returns An instance of `BrowserIndexedDB` (sublcass of `IOHandler`),\n * which can be used with, e.g., `tf.Model.save`.\n */\nexport function browserIndexedDB(modelPath: string): IOHandler {\n return new BrowserIndexedDB(modelPath);\n}\n\nfunction maybeStripScheme(key: string) {\n return key.startsWith(BrowserIndexedDB.URL_SCHEME) ?\n key.slice(BrowserIndexedDB.URL_SCHEME.length) :\n key;\n}\n\nexport class BrowserIndexedDBManager implements ModelStoreManager {\n private indexedDB: IDBFactory;\n\n constructor() {\n this.indexedDB = getIndexedDBFactory();\n }\n\n async listModels(): Promise<{[path: string]: ModelArtifactsInfo}> {\n return new Promise<{[path: string]: ModelArtifactsInfo}>(\n (resolve, reject) => {\n const openRequest =\n this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);\n openRequest.onupgradeneeded = () => setUpDatabase(openRequest);\n\n openRequest.onsuccess = () => {\n const db = openRequest.result;\n const tx = db.transaction(INFO_STORE_NAME, 'readonly');\n const store = tx.objectStore(INFO_STORE_NAME);\n // tslint:disable:max-line-length\n // Need to cast `store` as `any` here because TypeScript's DOM\n // library does not have the `getAll()` method even though the\n // method is supported in the latest version of most mainstream\n // browsers:\n // https://developer.mozilla.org/en-US/docs/Web/API/IDBObjectStore/getAll\n // tslint:enable:max-line-length\n // tslint:disable-next-line:no-any\n const getAllInfoRequest = (store as any).getAll() as IDBRequest;\n getAllInfoRequest.onsuccess = () => {\n const out: {[path: string]: ModelArtifactsInfo} = {};\n for (const item of getAllInfoRequest.result) {\n out[item.modelPath] = item.modelArtifactsInfo;\n }\n resolve(out);\n };\n getAllInfoRequest.onerror = error => {\n db.close();\n return reject(getAllInfoRequest.error);\n };\n tx.oncomplete = () => db.close();\n };\n openRequest.onerror = error => reject(openRequest.error);\n });\n }\n\n async removeModel(path: string): Promise {\n path = maybeStripScheme(path);\n return new Promise((resolve, reject) => {\n const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);\n openRequest.onupgradeneeded = () => setUpDatabase(openRequest);\n\n openRequest.onsuccess = () => {\n const db = openRequest.result;\n const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');\n const infoStore = infoTx.objectStore(INFO_STORE_NAME);\n\n const getInfoRequest = infoStore.get(path);\n let modelTx: IDBTransaction;\n getInfoRequest.onsuccess = () => {\n if (getInfoRequest.result == null) {\n db.close();\n return reject(new Error(\n `Cannot find model with path '${path}' ` +\n `in IndexedDB.`));\n } else {\n // First, delete the entry in the info store.\n const deleteInfoRequest = infoStore.delete(path);\n const deleteModelData = () => {\n // Second, delete the entry in the model store.\n modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');\n const modelStore = modelTx.objectStore(MODEL_STORE_NAME);\n const deleteModelRequest = modelStore.delete(path);\n deleteModelRequest.onsuccess = () =>\n resolve(getInfoRequest.result.modelArtifactsInfo);\n deleteModelRequest.onerror = error =>\n reject(getInfoRequest.error);\n };\n // Proceed with deleting model data regardless of whether deletion\n // of info data succeeds or not.\n deleteInfoRequest.onsuccess = deleteModelData;\n deleteInfoRequest.onerror = error => {\n deleteModelData();\n db.close();\n return reject(getInfoRequest.error);\n };\n }\n };\n getInfoRequest.onerror = error => {\n db.close();\n return reject(getInfoRequest.error);\n };\n\n infoTx.oncomplete = () => {\n if (modelTx == null) {\n db.close();\n } else {\n modelTx.oncomplete = () => db.close();\n }\n };\n };\n openRequest.onerror = error => reject(openRequest.error);\n });\n }\n}\n\nif (env().getBool('IS_BROWSER')) {\n // Wrap the construction and registration, to guard against browsers that\n // don't support Local Storage.\n try {\n ModelStoreManagerRegistry.registerManager(\n BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager());\n } catch (err) {\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../environment';\n\nimport {assert} from '../util';\nimport {arrayBufferToBase64String, base64StringToArrayBuffer, getModelArtifactsInfoForJSON} from './io_utils';\nimport {ModelStoreManagerRegistry} from './model_management';\nimport {IORouter, IORouterRegistry} from './router_registry';\nimport {IOHandler, ModelArtifacts, ModelArtifactsInfo, ModelStoreManager, SaveResult} from './types';\n\nconst PATH_SEPARATOR = '/';\nconst PATH_PREFIX = 'tensorflowjs_models';\nconst INFO_SUFFIX = 'info';\nconst MODEL_TOPOLOGY_SUFFIX = 'model_topology';\nconst WEIGHT_SPECS_SUFFIX = 'weight_specs';\nconst WEIGHT_DATA_SUFFIX = 'weight_data';\nconst MODEL_METADATA_SUFFIX = 'model_metadata';\n\n/**\n * Purge all tensorflow.js-saved model artifacts from local storage.\n *\n * @returns Paths of the models purged.\n */\nexport function purgeLocalStorageArtifacts(): string[] {\n if (!env().getBool('IS_BROWSER') ||\n typeof window === 'undefined' ||\n typeof window.localStorage === 'undefined') {\n throw new Error(\n 'purgeLocalStorageModels() cannot proceed because local storage is ' +\n 'unavailable in the current environment.');\n }\n const LS = window.localStorage;\n const purgedModelPaths: string[] = [];\n for (let i = 0; i < LS.length; ++i) {\n const key = LS.key(i);\n const prefix = PATH_PREFIX + PATH_SEPARATOR;\n if (key.startsWith(prefix) && key.length > prefix.length) {\n LS.removeItem(key);\n const modelName = getModelPathFromKey(key);\n if (purgedModelPaths.indexOf(modelName) === -1) {\n purgedModelPaths.push(modelName);\n }\n }\n }\n return purgedModelPaths;\n}\n\nfunction getModelKeys(path: string): {\n info: string,\n topology: string,\n weightSpecs: string,\n weightData: string,\n modelMetadata: string\n} {\n return {\n info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),\n topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR),\n weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR),\n weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),\n modelMetadata:\n [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)\n };\n}\n\n/**\n * Get model path from a local-storage key.\n *\n * E.g., 'tensorflowjs_models/my/model/1/info' --> 'my/model/1'\n *\n * @param key\n */\nfunction getModelPathFromKey(key: string) {\n const items = key.split(PATH_SEPARATOR);\n if (items.length < 3) {\n throw new Error(`Invalid key format: ${key}`);\n }\n return items.slice(1, items.length - 1).join(PATH_SEPARATOR);\n}\n\nfunction maybeStripScheme(key: string) {\n return key.startsWith(BrowserLocalStorage.URL_SCHEME) ?\n key.slice(BrowserLocalStorage.URL_SCHEME.length) :\n key;\n}\n\ndeclare type LocalStorageKeys = {\n info: string,\n topology: string,\n weightSpecs: string,\n weightData: string,\n modelMetadata: string\n};\n\n/**\n * IOHandler subclass: Browser Local Storage.\n *\n * See the doc string to `browserLocalStorage` for more details.\n */\nexport class BrowserLocalStorage implements IOHandler {\n protected readonly LS: Storage;\n protected readonly modelPath: string;\n protected readonly keys: LocalStorageKeys;\n\n static readonly URL_SCHEME = 'localstorage://';\n\n constructor(modelPath: string) {\n if (!env().getBool('IS_BROWSER') ||\n typeof window === 'undefined' ||\n typeof window.localStorage === 'undefined') {\n // TODO(cais): Add more info about what IOHandler subtypes are\n // available.\n // Maybe point to a doc page on the web and/or automatically determine\n // the available IOHandlers and print them in the error message.\n throw new Error(\n 'The current environment does not support local storage.');\n }\n this.LS = window.localStorage;\n\n if (modelPath == null || !modelPath) {\n throw new Error(\n 'For local storage, modelPath must not be null, undefined or empty.');\n }\n this.modelPath = modelPath;\n this.keys = getModelKeys(this.modelPath);\n }\n\n /**\n * Save model artifacts to browser local storage.\n *\n * See the documentation to `browserLocalStorage` for details on the saved\n * artifacts.\n *\n * @param modelArtifacts The model artifacts to be stored.\n * @returns An instance of SaveResult.\n */\n async save(modelArtifacts: ModelArtifacts): Promise {\n if (modelArtifacts.modelTopology instanceof ArrayBuffer) {\n throw new Error(\n 'BrowserLocalStorage.save() does not support saving model topology ' +\n 'in binary formats yet.');\n } else {\n const topology = JSON.stringify(modelArtifacts.modelTopology);\n const weightSpecs = JSON.stringify(modelArtifacts.weightSpecs);\n\n const modelArtifactsInfo: ModelArtifactsInfo =\n getModelArtifactsInfoForJSON(modelArtifacts);\n\n try {\n this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo));\n this.LS.setItem(this.keys.topology, topology);\n this.LS.setItem(this.keys.weightSpecs, weightSpecs);\n this.LS.setItem(\n this.keys.weightData,\n arrayBufferToBase64String(modelArtifacts.weightData));\n this.LS.setItem(this.keys.modelMetadata, JSON.stringify({\n format: modelArtifacts.format,\n generatedBy: modelArtifacts.generatedBy,\n convertedBy: modelArtifacts.convertedBy,\n userDefinedMetadata: modelArtifacts.userDefinedMetadata\n }));\n\n return {modelArtifactsInfo};\n } catch (err) {\n // If saving failed, clean up all items saved so far.\n this.LS.removeItem(this.keys.info);\n this.LS.removeItem(this.keys.topology);\n this.LS.removeItem(this.keys.weightSpecs);\n this.LS.removeItem(this.keys.weightData);\n this.LS.removeItem(this.keys.modelMetadata);\n\n throw new Error(\n `Failed to save model '${this.modelPath}' to local storage: ` +\n `size quota being exceeded is a possible cause of this failure: ` +\n `modelTopologyBytes=${modelArtifactsInfo.modelTopologyBytes}, ` +\n `weightSpecsBytes=${modelArtifactsInfo.weightSpecsBytes}, ` +\n `weightDataBytes=${modelArtifactsInfo.weightDataBytes}.`);\n }\n }\n }\n\n /**\n * Load a model from local storage.\n *\n * See the documentation to `browserLocalStorage` for details on the saved\n * artifacts.\n *\n * @returns The loaded model (if loading succeeds).\n */\n async load(): Promise {\n const info =\n JSON.parse(this.LS.getItem(this.keys.info)) as ModelArtifactsInfo;\n if (info == null) {\n throw new Error(\n `In local storage, there is no model with name '${this.modelPath}'`);\n }\n\n if (info.modelTopologyType !== 'JSON') {\n throw new Error(\n 'BrowserLocalStorage does not support loading non-JSON model ' +\n 'topology yet.');\n }\n\n const out: ModelArtifacts = {};\n\n // Load topology.\n const topology = JSON.parse(this.LS.getItem(this.keys.topology));\n if (topology == null) {\n throw new Error(\n `In local storage, the topology of model '${this.modelPath}' ` +\n `is missing.`);\n }\n out.modelTopology = topology;\n\n // Load weight specs.\n const weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs));\n if (weightSpecs == null) {\n throw new Error(\n `In local storage, the weight specs of model '${this.modelPath}' ` +\n `are missing.`);\n }\n out.weightSpecs = weightSpecs;\n\n // Load meta-data fields.\n const metadataString = this.LS.getItem(this.keys.modelMetadata);\n if (metadataString != null) {\n const metadata = JSON.parse(metadataString) as ModelArtifacts;\n out.format = metadata['format'];\n out.generatedBy = metadata['generatedBy'];\n out.convertedBy = metadata['convertedBy'];\n out.userDefinedMetadata = metadata['userDefinedMetadata'];\n }\n\n // Load weight data.\n const weightDataBase64 = this.LS.getItem(this.keys.weightData);\n if (weightDataBase64 == null) {\n throw new Error(\n `In local storage, the binary weight values of model ` +\n `'${this.modelPath}' are missing.`);\n }\n out.weightData = base64StringToArrayBuffer(weightDataBase64);\n\n return out;\n }\n}\n\nexport const localStorageRouter: IORouter = (url: string|string[]) => {\n if (!env().getBool('IS_BROWSER')) {\n return null;\n } else {\n if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) {\n return browserLocalStorage(\n url.slice(BrowserLocalStorage.URL_SCHEME.length));\n } else {\n return null;\n }\n }\n};\nIORouterRegistry.registerSaveRouter(localStorageRouter);\nIORouterRegistry.registerLoadRouter(localStorageRouter);\n\n/**\n * Factory function for local storage IOHandler.\n *\n * This `IOHandler` supports both `save` and `load`.\n *\n * For each model's saved artifacts, four items are saved to local storage.\n * - `${PATH_SEPARATOR}/${modelPath}/info`: Contains meta-info about the\n * model, such as date saved, type of the topology, size in bytes, etc.\n * - `${PATH_SEPARATOR}/${modelPath}/topology`: Model topology. For Keras-\n * style models, this is a stringized JSON.\n * - `${PATH_SEPARATOR}/${modelPath}/weight_specs`: Weight specs of the\n * model, can be used to decode the saved binary weight values (see\n * item below).\n * - `${PATH_SEPARATOR}/${modelPath}/weight_data`: Concatenated binary\n * weight values, stored as a base64-encoded string.\n *\n * Saving may throw an `Error` if the total size of the artifacts exceed the\n * browser-specific quota.\n *\n * @param modelPath A unique identifier for the model to be saved. Must be a\n * non-empty string.\n * @returns An instance of `IOHandler`, which can be used with, e.g.,\n * `tf.Model.save`.\n */\nexport function browserLocalStorage(modelPath: string): IOHandler {\n return new BrowserLocalStorage(modelPath);\n}\n\nexport class BrowserLocalStorageManager implements ModelStoreManager {\n private readonly LS: Storage;\n\n constructor() {\n assert(\n env().getBool('IS_BROWSER'),\n () => 'Current environment is not a web browser');\n assert(\n typeof window === 'undefined' ||\n typeof window.localStorage !== 'undefined',\n () => 'Current browser does not appear to support localStorage');\n this.LS = window.localStorage;\n }\n\n async listModels(): Promise<{[path: string]: ModelArtifactsInfo}> {\n const out: {[path: string]: ModelArtifactsInfo} = {};\n const prefix = PATH_PREFIX + PATH_SEPARATOR;\n const suffix = PATH_SEPARATOR + INFO_SUFFIX;\n for (let i = 0; i < this.LS.length; ++i) {\n const key = this.LS.key(i);\n if (key.startsWith(prefix) && key.endsWith(suffix)) {\n const modelPath = getModelPathFromKey(key);\n out[modelPath] = JSON.parse(this.LS.getItem(key)) as ModelArtifactsInfo;\n }\n }\n return out;\n }\n\n async removeModel(path: string): Promise {\n path = maybeStripScheme(path);\n const keys = getModelKeys(path);\n if (this.LS.getItem(keys.info) == null) {\n throw new Error(`Cannot find model at path '${path}'`);\n }\n const info = JSON.parse(this.LS.getItem(keys.info)) as ModelArtifactsInfo;\n\n this.LS.removeItem(keys.info);\n this.LS.removeItem(keys.topology);\n this.LS.removeItem(keys.weightSpecs);\n this.LS.removeItem(keys.weightData);\n return info;\n }\n}\n\nif (env().getBool('IS_BROWSER')) {\n // Wrap the construction and registration, to guard against browsers that\n // don't support Local Storage.\n try {\n ModelStoreManagerRegistry.registerManager(\n BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager());\n } catch (err) {\n }\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * IOHandlers related to files, such as browser-triggered file downloads,\n * user-selected files in browser.\n */\n\nimport {env} from '../environment';\n\nimport {basename, concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils';\nimport {IORouter, IORouterRegistry} from './router_registry';\nimport {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';\n\nconst DEFAULT_FILE_NAME_PREFIX = 'model';\nconst DEFAULT_JSON_EXTENSION_NAME = '.json';\nconst DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';\n\nfunction defer(f: () => T): Promise {\n return new Promise(resolve => setTimeout(resolve)).then(f);\n}\n\nexport class BrowserDownloads implements IOHandler {\n private readonly modelTopologyFileName: string;\n private readonly weightDataFileName: string;\n private readonly jsonAnchor: HTMLAnchorElement;\n private readonly weightDataAnchor: HTMLAnchorElement;\n\n static readonly URL_SCHEME = 'downloads://';\n\n constructor(fileNamePrefix?: string) {\n if (!env().getBool('IS_BROWSER')) {\n // TODO(cais): Provide info on what IOHandlers are available under the\n // current environment.\n throw new Error(\n 'browserDownloads() cannot proceed because the current environment ' +\n 'is not a browser.');\n }\n\n if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) {\n fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length);\n }\n if (fileNamePrefix == null || fileNamePrefix.length === 0) {\n fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;\n }\n\n this.modelTopologyFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;\n this.weightDataFileName =\n fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;\n }\n\n async save(modelArtifacts: ModelArtifacts): Promise {\n if (typeof (document) === 'undefined') {\n throw new Error(\n 'Browser downloads are not supported in ' +\n 'this environment since `document` is not present');\n }\n const weightsURL = window.URL.createObjectURL(new Blob(\n [modelArtifacts.weightData], {type: 'application/octet-stream'}));\n\n if (modelArtifacts.modelTopology instanceof ArrayBuffer) {\n throw new Error(\n 'BrowserDownloads.save() does not support saving model topology ' +\n 'in binary formats yet.');\n } else {\n const weightsManifest: WeightsManifestConfig = [{\n paths: ['./' + this.weightDataFileName],\n weights: modelArtifacts.weightSpecs\n }];\n const modelTopologyAndWeightManifest: ModelJSON = {\n modelTopology: modelArtifacts.modelTopology,\n format: modelArtifacts.format,\n generatedBy: modelArtifacts.generatedBy,\n convertedBy: modelArtifacts.convertedBy,\n weightsManifest\n };\n const modelTopologyAndWeightManifestURL =\n window.URL.createObjectURL(new Blob(\n [JSON.stringify(modelTopologyAndWeightManifest)],\n {type: 'application/json'}));\n\n // If anchor elements are not provided, create them without attaching them\n // to parents, so that the downloaded file names can be controlled.\n const jsonAnchor = this.jsonAnchor == null ? document.createElement('a') :\n this.jsonAnchor;\n jsonAnchor.download = this.modelTopologyFileName;\n jsonAnchor.href = modelTopologyAndWeightManifestURL;\n // Trigger downloads by evoking a click event on the download anchors.\n // When multiple downloads are started synchronously, Firefox will only\n // save the last one.\n await defer(() => jsonAnchor.dispatchEvent(new MouseEvent('click')));\n\n if (modelArtifacts.weightData != null) {\n const weightDataAnchor = this.weightDataAnchor == null ?\n document.createElement('a') :\n this.weightDataAnchor;\n weightDataAnchor.download = this.weightDataFileName;\n weightDataAnchor.href = weightsURL;\n await defer(\n () => weightDataAnchor.dispatchEvent(new MouseEvent('click')));\n }\n\n return {modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts)};\n }\n }\n}\n\nclass BrowserFiles implements IOHandler {\n private readonly files: File[];\n\n constructor(files: File[]) {\n if (files == null || files.length < 1) {\n throw new Error(\n `When calling browserFiles, at least 1 file is required, ` +\n `but received ${files}`);\n }\n this.files = files;\n }\n\n async load(): Promise {\n const jsonFile = this.files[0];\n const weightFiles = this.files.slice(1);\n\n return new Promise((resolve, reject) => {\n const jsonReader = new FileReader();\n jsonReader.onload = (event: Event) => {\n // tslint:disable-next-line:no-any\n const modelJSON = JSON.parse((event.target as any).result) as ModelJSON;\n const modelTopology = modelJSON.modelTopology;\n if (modelTopology == null) {\n reject(new Error(\n `modelTopology field is missing from file ${jsonFile.name}`));\n return;\n }\n\n if (weightFiles.length === 0) {\n resolve({modelTopology});\n }\n\n const weightsManifest = modelJSON.weightsManifest;\n if (weightsManifest == null) {\n reject(new Error(\n `weightManifest field is missing from file ${jsonFile.name}`));\n return;\n }\n\n let pathToFile: {[path: string]: File};\n try {\n pathToFile =\n this.checkManifestAndWeightFiles(weightsManifest, weightFiles);\n } catch (err) {\n reject(err);\n return;\n }\n\n const weightSpecs: WeightsManifestEntry[] = [];\n const paths: string[] = [];\n const perFileBuffers: ArrayBuffer[] = [];\n weightsManifest.forEach(weightsGroup => {\n weightsGroup.paths.forEach(path => {\n paths.push(path);\n perFileBuffers.push(null);\n });\n weightSpecs.push(...weightsGroup.weights);\n });\n\n weightsManifest.forEach(weightsGroup => {\n weightsGroup.paths.forEach(path => {\n const weightFileReader = new FileReader();\n weightFileReader.onload = (event: Event) => {\n // tslint:disable-next-line:no-any\n const weightData = (event.target as any).result as ArrayBuffer;\n const index = paths.indexOf(path);\n perFileBuffers[index] = weightData;\n if (perFileBuffers.indexOf(null) === -1) {\n resolve({\n modelTopology,\n weightSpecs,\n weightData: concatenateArrayBuffers(perFileBuffers),\n format: modelJSON.format,\n generatedBy: modelJSON.generatedBy,\n convertedBy: modelJSON.convertedBy,\n userDefinedMetadata: modelJSON.userDefinedMetadata\n });\n }\n };\n weightFileReader.onerror = error =>\n reject(`Failed to weights data from file of path '${path}'.`);\n weightFileReader.readAsArrayBuffer(pathToFile[path]);\n });\n });\n };\n jsonReader.onerror = error => reject(\n `Failed to read model topology and weights manifest JSON ` +\n `from file '${jsonFile.name}'. BrowserFiles supports loading ` +\n `Keras-style tf.Model artifacts only.`);\n jsonReader.readAsText(jsonFile);\n });\n }\n\n /**\n * Check the compatibility between weights manifest and weight files.\n */\n private checkManifestAndWeightFiles(\n manifest: WeightsManifestConfig, files: File[]): {[path: string]: File} {\n const basenames: string[] = [];\n const fileNames = files.map(file => basename(file.name));\n const pathToFile: {[path: string]: File} = {};\n for (const group of manifest) {\n group.paths.forEach(path => {\n const pathBasename = basename(path);\n if (basenames.indexOf(pathBasename) !== -1) {\n throw new Error(\n `Duplicate file basename found in weights manifest: ` +\n `'${pathBasename}'`);\n }\n basenames.push(pathBasename);\n if (fileNames.indexOf(pathBasename) === -1) {\n throw new Error(\n `Weight file with basename '${pathBasename}' is not provided.`);\n } else {\n pathToFile[path] = files[fileNames.indexOf(pathBasename)];\n }\n });\n }\n\n if (basenames.length !== files.length) {\n throw new Error(\n `Mismatch in the number of files in weights manifest ` +\n `(${basenames.length}) and the number of weight files provided ` +\n `(${files.length}).`);\n }\n return pathToFile;\n }\n}\n\nexport const browserDownloadsRouter: IORouter = (url: string|string[]) => {\n if (!env().getBool('IS_BROWSER')) {\n return null;\n } else {\n if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) {\n return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length));\n } else {\n return null;\n }\n }\n};\nIORouterRegistry.registerSaveRouter(browserDownloadsRouter);\n\n/**\n * Creates an IOHandler that triggers file downloads from the browser.\n *\n * The returned `IOHandler` instance can be used as model exporting methods such\n * as `tf.Model.save` and supports only saving.\n *\n * ```js\n * const model = tf.sequential();\n * model.add(tf.layers.dense(\n * {units: 1, inputShape: [10], activation: 'sigmoid'}));\n * const saveResult = await model.save('downloads://mymodel');\n * // This will trigger downloading of two files:\n * // 'mymodel.json' and 'mymodel.weights.bin'.\n * console.log(saveResult);\n * ```\n *\n * @param fileNamePrefix Prefix name of the files to be downloaded. For use with\n * `tf.Model`, `fileNamePrefix` should follow either of the following two\n * formats:\n * 1. `null` or `undefined`, in which case the default file\n * names will be used:\n * - 'model.json' for the JSON file containing the model topology and\n * weights manifest.\n * - 'model.weights.bin' for the binary file containing the binary weight\n * values.\n * 2. A single string or an Array of a single string, as the file name prefix.\n * For example, if `'foo'` is provided, the downloaded JSON\n * file and binary weights file will be named 'foo.json' and\n * 'foo.weights.bin', respectively.\n * @param config Additional configuration for triggering downloads.\n * @returns An instance of `BrowserDownloads` `IOHandler`.\n */\n/**\n * @doc {\n * heading: 'Models',\n * subheading: 'Loading',\n * namespace: 'io',\n * ignoreCI: true\n * }\n */\nexport function browserDownloads(fileNamePrefix = 'model'): IOHandler {\n return new BrowserDownloads(fileNamePrefix);\n}\n\n/**\n * Creates an IOHandler that loads model artifacts from user-selected files.\n *\n * This method can be used for loading from files such as user-selected files\n * in the browser.\n * When used in conjunction with `tf.loadLayersModel`, an instance of\n * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.\n *\n * ```js\n * // Note: This code snippet won't run properly without the actual file input\n * // elements in the HTML DOM.\n *\n * // Suppose there are two HTML file input (``)\n * // elements.\n * const uploadJSONInput = document.getElementById('upload-json');\n * const uploadWeightsInput = document.getElementById('upload-weights');\n * const model = await tf.loadLayersModel(tf.io.browserFiles(\n * [uploadJSONInput.files[0], uploadWeightsInput.files[0]]));\n * ```\n *\n * @param files `File`s to load from. Currently, this function supports only\n * loading from files that contain Keras-style models (i.e., `tf.Model`s), for\n * which an `Array` of `File`s is expected (in that order):\n * - A JSON file containing the model topology and weight manifest.\n * - Optionally, One or more binary files containing the binary weights.\n * These files must have names that match the paths in the `weightsManifest`\n * contained by the aforementioned JSON file, or errors will be thrown\n * during loading. These weights files have the same format as the ones\n * generated by `tensorflowjs_converter` that comes with the `tensorflowjs`\n * Python PIP package. If no weights files are provided, only the model\n * topology will be loaded from the JSON file above.\n * @returns An instance of `Files` `IOHandler`.\n */\n/**\n * @doc {\n * heading: 'Models',\n * subheading: 'Loading',\n * namespace: 'io',\n * ignoreCI: true\n * }\n */\nexport function browserFiles(files: File[]): IOHandler {\n return new BrowserFiles(files);\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {assert} from '../util';\n\nimport {OnProgressCallback} from './types';\n\n/**\n * Monitor Promise.all progress, fire onProgress callback function.\n *\n * @param promises Promise list going to be monitored\n * @param onProgress Callback function. Fired when a promise resolved.\n * @param startFraction Optional fraction start. Default to 0.\n * @param endFraction Optional fraction end. Default to 1.\n */\nexport function monitorPromisesProgress(\n promises: Array>, onProgress: OnProgressCallback,\n startFraction?: number, endFraction?: number) {\n checkPromises(promises);\n startFraction = startFraction == null ? 0 : startFraction;\n endFraction = endFraction == null ? 1 : endFraction;\n checkFraction(startFraction, endFraction);\n let resolvedPromise = 0;\n\n const registerMonitor = (promise: Promise<{}>) => {\n promise.then(value => {\n const fraction = startFraction +\n ++resolvedPromise / promises.length * (endFraction - startFraction);\n // pass fraction as parameter to callback function.\n onProgress(fraction);\n return value;\n });\n return promise;\n };\n\n function checkPromises(promises: Array>): void {\n assert(\n promises != null && Array.isArray(promises) && promises.length > 0,\n () => 'promises must be a none empty array');\n }\n\n function checkFraction(startFraction: number, endFraction: number): void {\n assert(\n startFraction >= 0 && startFraction <= 1,\n () => `Progress fraction must be in range [0, 1], but ` +\n `got startFraction ${startFraction}`);\n assert(\n endFraction >= 0 && endFraction <= 1,\n () => `Progress fraction must be in range [0, 1], but ` +\n `got endFraction ${endFraction}`);\n assert(\n endFraction >= startFraction,\n () => `startFraction must be no more than endFraction, but ` +\n `got startFraction ${startFraction} and endFraction ` +\n `${endFraction}`);\n }\n\n return Promise.all(promises.map(registerMonitor));\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../environment';\n\nimport {NamedTensorMap} from '../tensor_types';\nimport * as util from '../util';\nimport {decodeWeights} from './io_utils';\nimport {monitorPromisesProgress} from './progress';\nimport {DTYPE_VALUE_SIZE_MAP, LoadOptions, WeightsManifestConfig, WeightsManifestEntry} from './types';\n\n/**\n * Reads binary weights data from a number of URLs.\n *\n * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.\n * @param requestOptions RequestInit (options) for the HTTP requests.\n * @param fetchFunc Optional overriding value for the `window.fetch` function.\n * @param onProgress Optional, progress callback function, fired periodically\n * before the load is completed.\n * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same\n * length as `fetchURLs`.\n */\nexport async function loadWeightsAsArrayBuffer(\n fetchURLs: string[], loadOptions?: LoadOptions): Promise {\n if (loadOptions == null) {\n loadOptions = {};\n }\n\n const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch :\n loadOptions.fetchFunc;\n\n // Create the requests for all of the weights in parallel.\n const requests = fetchURLs.map(\n fetchURL =>\n fetchFunc(fetchURL, loadOptions.requestInit, {isBinary: true}));\n\n const fetchStartFraction = 0;\n const fetchEndFraction = 0.5;\n\n const responses = loadOptions.onProgress == null ?\n await Promise.all(requests) :\n await monitorPromisesProgress(\n requests, loadOptions.onProgress, fetchStartFraction,\n fetchEndFraction);\n\n const bufferPromises = responses.map(response => response.arrayBuffer());\n\n const bufferStartFraction = 0.5;\n const bufferEndFraction = 1;\n\n const buffers = loadOptions.onProgress == null ?\n await Promise.all(bufferPromises) :\n await monitorPromisesProgress(\n bufferPromises, loadOptions.onProgress, bufferStartFraction,\n bufferEndFraction);\n return buffers;\n}\n\n/**\n * Reads a weights manifest JSON configuration, fetches the weights and\n * returns them as `Tensor`s.\n *\n * @param manifest The weights manifest JSON.\n * @param filePathPrefix The path prefix for filenames given in the manifest.\n * Defaults to the empty string.\n * @param weightNames The names of the weights to be fetched.\n */\nexport async function loadWeights(\n manifest: WeightsManifestConfig, filePathPrefix = '',\n weightNames?: string[],\n requestInit?: RequestInit): Promise {\n // TODO(nsthorat): Groups are currently fetched atomically. If you need a\n // single weight from a group, the whole group will be fetched. At a future\n // date, we should support fetching only the individual shards within a\n // group that are needed to reconstruct the requested weight.\n // TODO(cais): Use `decodeWeights` for implementation.\n\n const fetchWeights = (fetchUrls: string[]) =>\n loadWeightsAsArrayBuffer(fetchUrls, {requestInit});\n const loadWeights = weightsLoaderFactory(fetchWeights);\n\n return loadWeights(manifest, filePathPrefix, weightNames);\n}\n\n/**\n * Creates a function, which reads a weights manifest JSON configuration,\n * fetches the weight files using the specified function and returns them as\n * `Tensor`s.\n *\n * ```js\n * // example for creating a nodejs weight loader, which reads the weight files\n * // from disk using fs.readFileSync\n *\n * import * as fs from 'fs'\n *\n * const fetchWeightsFromDisk = (filePaths: string[]) =>\n * filePaths.map(filePath => fs.readFileSync(filePath).buffer)\n *\n * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk)\n *\n * const manifest = JSON.parse(\n * fs.readFileSync('./my_model-weights_manifest').toString()\n * )\n * const weightMap = await loadWeights(manifest, './')\n * ```\n * @param fetchWeightsFunction The function used for fetching the weight files.\n * @returns Weight loading function.\n */\nexport function weightsLoaderFactory(\n fetchWeightsFunction: (fetchUrls: string[]) => Promise):\n (manifest: WeightsManifestConfig, filePathPrefix?: string,\n weightNames?: string[]) => Promise {\n return async(\n manifest: WeightsManifestConfig, filePathPrefix = '',\n weightNames?: string[]): Promise => {\n // Collect all the groups, weights, and their relative offsets to be\n // fetched.\n const groupIndicesToFetchMap = manifest.map(() => false);\n const groupWeightsToFetch: {\n [group: number]: Array<{\n manifestEntry: WeightsManifestEntry; groupOffset: number;\n sizeBytes: number;\n }>\n } = {};\n const weightsFound =\n weightNames != null ? weightNames.map(() => false) : [];\n const allManifestWeightNames: string[] = [];\n manifest.forEach((manifestGroupConfig, groupIndex) => {\n let groupOffset = 0;\n manifestGroupConfig.weights.forEach(weightsEntry => {\n const rawDtype = ('quantization' in weightsEntry) ?\n weightsEntry.quantization.dtype :\n weightsEntry.dtype;\n\n const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] *\n util.sizeFromShape(weightsEntry.shape);\n\n const enqueueWeightsForFetchingFn = () => {\n groupIndicesToFetchMap[groupIndex] = true;\n if (groupWeightsToFetch[groupIndex] == null) {\n groupWeightsToFetch[groupIndex] = [];\n }\n\n groupWeightsToFetch[groupIndex].push({\n manifestEntry: weightsEntry,\n groupOffset,\n sizeBytes: weightsBytes\n });\n };\n\n if (weightNames != null) {\n weightNames.forEach((weightName, weightIndex) => {\n if (weightName === weightsEntry.name) {\n enqueueWeightsForFetchingFn();\n weightsFound[weightIndex] = true;\n }\n });\n } else {\n enqueueWeightsForFetchingFn();\n }\n\n allManifestWeightNames.push(weightsEntry.name);\n groupOffset += weightsBytes;\n });\n });\n\n if (!weightsFound.every(found => found)) {\n const weightsNotFound = weightNames.filter((_, i) => !weightsFound[i]);\n throw new Error(\n `Could not find weights in manifest with names: ` +\n `${weightsNotFound.join(', ')}. \\n` +\n `Manifest JSON has weights with names: ` +\n `${allManifestWeightNames.join(', ')}.`);\n }\n\n // Convert the one-hot boolean groupId => shouldFetch map to a list of group\n // IDs.\n const groupIndicesToFetch =\n groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => {\n if (shouldFetch) {\n accumulator.push(i);\n }\n return accumulator;\n }, []);\n\n const fetchUrls: string[] = [];\n groupIndicesToFetch.forEach(i => {\n manifest[i].paths.forEach(filepath => {\n const fetchUrl = filePathPrefix +\n (!filePathPrefix.endsWith('/') ? '/' : '') + filepath;\n fetchUrls.push(fetchUrl);\n });\n });\n const buffers = await fetchWeightsFunction(fetchUrls);\n\n const weightsTensorMap: NamedTensorMap = {};\n let bufferIndexOffset = 0;\n groupIndicesToFetch.forEach(i => {\n const numBuffers = manifest[i].paths.length;\n\n let groupBytes = 0;\n for (let i = 0; i < numBuffers; i++) {\n groupBytes += buffers[bufferIndexOffset + i].byteLength;\n }\n\n // Create a buffer for the whole group.\n const groupBuffer = new ArrayBuffer(groupBytes);\n const groupByteBuffer = new Uint8Array(groupBuffer);\n let groupBufferOffset = 0;\n for (let i = 0; i < numBuffers; i++) {\n const buffer = new Uint8Array(buffers[bufferIndexOffset + i]);\n groupByteBuffer.set(buffer, groupBufferOffset);\n groupBufferOffset += buffer.byteLength;\n }\n\n const weightsEntries = groupWeightsToFetch[i];\n weightsEntries.forEach(weightsEntry => {\n const byteBuffer = groupBuffer.slice(\n weightsEntry.groupOffset,\n weightsEntry.groupOffset + weightsEntry.sizeBytes);\n const nameToTensorMap =\n decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);\n for (const name in nameToTensorMap) {\n weightsTensorMap[name] = nameToTensorMap[name];\n }\n });\n\n bufferIndexOffset += numBuffers;\n });\n\n return weightsTensorMap;\n };\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * IOHandler implementations based on HTTP requests in the web browser.\n *\n * Uses [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).\n */\n\nimport {env} from '../environment';\n\nimport {assert} from '../util';\nimport {concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils';\nimport {IORouter, IORouterRegistry} from './router_registry';\nimport {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';\nimport {loadWeightsAsArrayBuffer} from './weights_loader';\n\nconst OCTET_STREAM_MIME_TYPE = 'application/octet-stream';\nconst JSON_TYPE = 'application/json';\nexport class HTTPRequest implements IOHandler {\n protected readonly path: string;\n protected readonly requestInit: RequestInit;\n\n private readonly fetch: Function;\n\n readonly DEFAULT_METHOD = 'POST';\n\n static readonly URL_SCHEME_REGEX = /^https?:\\/\\//;\n\n private readonly weightPathPrefix: string;\n private readonly onProgress: OnProgressCallback;\n\n constructor(path: string, loadOptions?: LoadOptions) {\n if (loadOptions == null) {\n loadOptions = {};\n }\n this.weightPathPrefix = loadOptions.weightPathPrefix;\n this.onProgress = loadOptions.onProgress;\n\n if (loadOptions.fetchFunc != null) {\n assert(\n typeof loadOptions.fetchFunc === 'function',\n () => 'Must pass a function that matches the signature of ' +\n '`fetch` (see ' +\n 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)');\n this.fetch = loadOptions.fetchFunc;\n } else {\n this.fetch = env().platform.fetch;\n }\n\n assert(\n path != null && path.length > 0,\n () => 'URL path for http must not be null, undefined or ' +\n 'empty.');\n\n if (Array.isArray(path)) {\n assert(\n path.length === 2,\n () => 'URL paths for http must have a length of 2, ' +\n `(actual length is ${path.length}).`);\n }\n this.path = path;\n\n if (loadOptions.requestInit != null &&\n loadOptions.requestInit.body != null) {\n throw new Error(\n 'requestInit is expected to have no pre-existing body, but has one.');\n }\n this.requestInit = loadOptions.requestInit || {};\n }\n\n async save(modelArtifacts: ModelArtifacts): Promise {\n if (modelArtifacts.modelTopology instanceof ArrayBuffer) {\n throw new Error(\n 'BrowserHTTPRequest.save() does not support saving model topology ' +\n 'in binary formats yet.');\n }\n\n const init = Object.assign({method: this.DEFAULT_METHOD}, this.requestInit);\n init.body = new FormData();\n\n const weightsManifest: WeightsManifestConfig = [{\n paths: ['./model.weights.bin'],\n weights: modelArtifacts.weightSpecs,\n }];\n const modelTopologyAndWeightManifest: ModelJSON = {\n modelTopology: modelArtifacts.modelTopology,\n format: modelArtifacts.format,\n generatedBy: modelArtifacts.generatedBy,\n convertedBy: modelArtifacts.convertedBy,\n userDefinedMetadata: modelArtifacts.userDefinedMetadata,\n weightsManifest\n };\n\n init.body.append(\n 'model.json',\n new Blob(\n [JSON.stringify(modelTopologyAndWeightManifest)],\n {type: JSON_TYPE}),\n 'model.json');\n\n if (modelArtifacts.weightData != null) {\n init.body.append(\n 'model.weights.bin',\n new Blob([modelArtifacts.weightData], {type: OCTET_STREAM_MIME_TYPE}),\n 'model.weights.bin');\n }\n\n const response = await this.fetch(this.path, init);\n\n if (response.ok) {\n return {\n modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts),\n responses: [response],\n };\n } else {\n throw new Error(\n `BrowserHTTPRequest.save() failed due to HTTP response status ` +\n `${response.status}.`);\n }\n }\n\n /**\n * Load model artifacts via HTTP request(s).\n *\n * See the documentation to `tf.io.http` for details on the saved\n * artifacts.\n *\n * @returns The loaded model artifacts (if loading succeeds).\n */\n async load(): Promise {\n const modelConfigRequest = await this.fetch(this.path, this.requestInit);\n\n if (!modelConfigRequest.ok) {\n throw new Error(\n `Request to ${this.path} failed with status code ` +\n `${modelConfigRequest.status}. Please verify this URL points to ` +\n `the model JSON of the model to load.`);\n }\n let modelConfig: ModelJSON;\n try {\n modelConfig = await modelConfigRequest.json();\n } catch (e) {\n let message = `Failed to parse model JSON of response from ${this.path}.`;\n // TODO(nsthorat): Remove this after some time when we're comfortable that\n // .pb files are mostly gone.\n if (this.path.endsWith('.pb')) {\n message += ' Your path contains a .pb file extension. ' +\n 'Support for .pb models have been removed in TensorFlow.js 1.0 ' +\n 'in favor of .json models. You can re-convert your Python ' +\n 'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' +\n 'or you can convert your.pb models with the \\'pb2json\\'' +\n 'NPM script in the tensorflow/tfjs-converter repository.';\n } else {\n message += ' Please make sure the server is serving valid ' +\n 'JSON for this request.';\n }\n throw new Error(message);\n }\n const modelTopology = modelConfig.modelTopology;\n const weightsManifest = modelConfig.weightsManifest;\n const generatedBy = modelConfig.generatedBy;\n const convertedBy = modelConfig.convertedBy;\n const format = modelConfig.format;\n const userDefinedMetadata = modelConfig.userDefinedMetadata;\n\n // We do not allow both modelTopology and weightsManifest to be missing.\n if (modelTopology == null && weightsManifest == null) {\n throw new Error(\n `The JSON from HTTP path ${this.path} contains neither model ` +\n `topology or manifest for weights.`);\n }\n\n let weightSpecs: WeightsManifestEntry[];\n let weightData: ArrayBuffer;\n if (weightsManifest != null) {\n const results = await this.loadWeights(weightsManifest);\n [weightSpecs, weightData] = results;\n }\n\n return {\n modelTopology,\n weightSpecs,\n weightData,\n userDefinedMetadata,\n generatedBy,\n convertedBy,\n format\n };\n }\n\n private async loadWeights(weightsManifest: WeightsManifestConfig):\n Promise<[WeightsManifestEntry[], ArrayBuffer]> {\n const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;\n const [prefix, suffix] = parseUrl(weightPath);\n const pathPrefix = this.weightPathPrefix || prefix;\n\n const weightSpecs = [];\n for (const entry of weightsManifest) {\n weightSpecs.push(...entry.weights);\n }\n\n const fetchURLs: string[] = [];\n weightsManifest.forEach(weightsGroup => {\n weightsGroup.paths.forEach(path => {\n fetchURLs.push(pathPrefix + path + suffix);\n });\n });\n const buffers = await loadWeightsAsArrayBuffer(fetchURLs, {\n requestInit: this.requestInit,\n fetchFunc: this.fetch,\n onProgress: this.onProgress\n });\n return [weightSpecs, concatenateArrayBuffers(buffers)];\n }\n}\n\n/**\n * Extract the prefix and suffix of the url, where the prefix is the path before\n * the last file, and suffix is the search params after the last file.\n * ```\n * const url = 'http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file'\n * [prefix, suffix] = parseUrl(url)\n * // prefix = 'http://tfhub.dev/model/1/'\n * // suffix = '?tfjs-format=file'\n * ```\n * @param url the model url to be parsed.\n */\nexport function parseUrl(url: string): [string, string] {\n const lastSlash = url.lastIndexOf('/');\n const lastSearchParam = url.lastIndexOf('?');\n const prefix = url.substring(0, lastSlash);\n const suffix =\n lastSearchParam > lastSlash ? url.substring(lastSearchParam) : '';\n return [prefix + '/', suffix];\n}\n\nexport function isHTTPScheme(url: string): boolean {\n return url.match(HTTPRequest.URL_SCHEME_REGEX) != null;\n}\n\nexport const httpRouter: IORouter =\n (url: string, onProgress?: OnProgressCallback) => {\n if (typeof fetch === 'undefined') {\n // `http` uses `fetch` or `node-fetch`, if one wants to use it in\n // an environment that is not the browser or node they have to setup a\n // global fetch polyfill.\n return null;\n } else {\n let isHTTP = true;\n if (Array.isArray(url)) {\n isHTTP = url.every(urlItem => isHTTPScheme(urlItem));\n } else {\n isHTTP = isHTTPScheme(url);\n }\n if (isHTTP) {\n return http(url, {onProgress});\n }\n }\n return null;\n };\nIORouterRegistry.registerSaveRouter(httpRouter);\nIORouterRegistry.registerLoadRouter(httpRouter);\n\n/**\n * Creates an IOHandler subtype that sends model artifacts to HTTP server.\n *\n * An HTTP request of the `multipart/form-data` mime type will be sent to the\n * `path` URL. The form data includes artifacts that represent the topology\n * and/or weights of the model. In the case of Keras-style `tf.Model`, two\n * blobs (files) exist in form-data:\n * - A JSON file consisting of `modelTopology` and `weightsManifest`.\n * - A binary weights file consisting of the concatenated weight values.\n * These files are in the same format as the one generated by\n * [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html).\n *\n * The following code snippet exemplifies the client-side code that uses this\n * function:\n *\n * ```js\n * const model = tf.sequential();\n * model.add(\n * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));\n *\n * const saveResult = await model.save(tf.io.http(\n * 'http://model-server:5000/upload', {requestInit: {method: 'PUT'}}));\n * console.log(saveResult);\n * ```\n *\n * If the default `POST` method is to be used, without any custom parameters\n * such as headers, you can simply pass an HTTP or HTTPS URL to `model.save`:\n *\n * ```js\n * const saveResult = await model.save('http://model-server:5000/upload');\n * ```\n *\n * The following GitHub Gist\n * https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864\n * implements a server based on [flask](https://github.com/pallets/flask) that\n * can receive the request. Upon receiving the model artifacts via the requst,\n * this particular server reconsistutes instances of [Keras\n * Models](https://keras.io/models/model/) in memory.\n *\n *\n * @param path A URL path to the model.\n * Can be an absolute HTTP path (e.g.,\n * 'http://localhost:8000/model-upload)') or a relative path (e.g.,\n * './model-upload').\n * @param requestInit Request configurations to be used when sending\n * HTTP request to server using `fetch`. It can contain fields such as\n * `method`, `credentials`, `headers`, `mode`, etc. See\n * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request\n * for more information. `requestInit` must not have a body, because the\n * body will be set by TensorFlow.js. File blobs representing the model\n * topology (filename: 'model.json') and the weights of the model (filename:\n * 'model.weights.bin') will be appended to the body. If `requestInit` has a\n * `body`, an Error will be thrown.\n * @param loadOptions Optional configuration for the loading. It includes the\n * following fields:\n * - weightPathPrefix Optional, this specifies the path prefix for weight\n * files, by default this is calculated from the path param.\n * - fetchFunc Optional, custom `fetch` function. E.g., in Node.js,\n * the `fetch` from node-fetch can be used here.\n * - onProgress Optional, progress callback function, fired periodically\n * before the load is completed.\n * @returns An instance of `IOHandler`.\n */\n/**\n * @doc {\n * heading: 'Models',\n * subheading: 'Loading',\n * namespace: 'io',\n * ignoreCI: true\n * }\n */\nexport function http(path: string, loadOptions?: LoadOptions): IOHandler {\n return new HTTPRequest(path, loadOptions);\n}\n\n/**\n * Deprecated. Use `tf.io.http`.\n * @param path\n * @param loadOptions\n */\nexport function browserHTTPRequest(\n path: string, loadOptions?: LoadOptions): IOHandler {\n return http(path, loadOptions);\n}\n","/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * IOHandlers that pass through the in-memory ModelArtifacts format.\n */\n\nimport {IOHandler, ModelArtifacts, SaveResult, TrainingConfig, WeightsManifestEntry} from './types';\n\nclass PassthroughLoader implements IOHandler {\n constructor(private readonly modelArtifacts?: ModelArtifacts) {}\n\n async load(): Promise {\n return this.modelArtifacts;\n }\n}\n\nclass PassthroughSaver implements IOHandler {\n constructor(\n private readonly saveHandler:\n (artifacts: ModelArtifacts) => Promise) {}\n\n async save(modelArtifacts: ModelArtifacts) {\n return this.saveHandler(modelArtifacts);\n }\n}\n\n/**\n * Creates an IOHandler that loads model artifacts from memory.\n *\n * When used in conjunction with `tf.loadLayersModel`, an instance of\n * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.\n *\n * ```js\n * const model = await tf.loadLayersModel(tf.io.fromMemory(\n * modelTopology, weightSpecs, weightData));\n * ```\n *\n * @param modelArtifacts a object containing model topology (i.e., parsed from\n * the JSON format).\n * @param weightSpecs An array of `WeightsManifestEntry` objects describing the\n * names, shapes, types, and quantization of the weight data.\n * @param weightData A single `ArrayBuffer` containing the weight data,\n * concatenated in the order described by the weightSpecs.\n * @param trainingConfig Model training configuration. Optional.\n *\n * @returns A passthrough `IOHandler` that simply loads the provided data.\n */\nexport function fromMemory(\n modelArtifacts: {}|ModelArtifacts, weightSpecs?: WeightsManifestEntry[],\n weightData?: ArrayBuffer, trainingConfig?: TrainingConfig): IOHandler {\n if (arguments.length === 1) {\n const isModelArtifacts =\n (modelArtifacts as ModelArtifacts).modelTopology != null ||\n (modelArtifacts as ModelArtifacts).weightSpecs != null;\n if (isModelArtifacts) {\n return new PassthroughLoader(modelArtifacts as ModelArtifacts);\n } else {\n // Legacy support: with only modelTopology.\n // TODO(cais): Remove this deprecated API.\n console.warn(\n 'Please call tf.io.fromMemory() with only one argument. ' +\n 'The argument should be of type ModelArtifacts. ' +\n 'The multi-argument signature of tf.io.fromMemory() has been ' +\n 'deprecated and will be removed in a future release.');\n return new PassthroughLoader({modelTopology: modelArtifacts as {}});\n }\n } else {\n // Legacy support.\n // TODO(cais): Remove this deprecated API.\n console.warn(\n 'Please call tf.io.fromMemory() with only one argument. ' +\n 'The argument should be of type ModelArtifacts. ' +\n 'The multi-argument signature of tf.io.fromMemory() has been ' +\n 'deprecated and will be removed in a future release.');\n return new PassthroughLoader({\n modelTopology: modelArtifacts as {},\n weightSpecs,\n weightData,\n trainingConfig\n });\n }\n}\n\n/**\n * Creates an IOHandler that passes saved model artifacts to a callback.\n *\n * ```js\n * function handleSave(artifacts) {\n * // ... do something with the artifacts ...\n * return {modelArtifactsInfo: {...}, ...};\n * }\n *\n * const saveResult = model.save(tf.io.withSaveHandler(handleSave));\n * ```\n *\n * @param saveHandler A function that accepts a `ModelArtifacts` and returns a\n * `SaveResult`.\n */\nexport function withSaveHandler(\n saveHandler: (artifacts: ModelArtifacts) =>\n Promise): IOHandler {\n return new PassthroughSaver(saveHandler);\n}\n","/**\n * @license\n * Copyright 2019 Google Inc. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {getKernel} from '../kernel_registry';\nimport {Tensor, Tensor2D, Tensor3D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {PixelData, TensorLike} from '../types';\n\nimport {op} from './operation';\nimport {tensor3d} from './tensor_ops';\n\nlet fromPixels2DContext: CanvasRenderingContext2D;\n\n/**\n * Creates a `tf.Tensor` from an image.\n *\n * ```js\n * const image = new ImageData(1, 1);\n * image.data[0] = 100;\n * image.data[1] = 150;\n * image.data[2] = 200;\n * image.data[3] = 255;\n *\n * tf.browser.fromPixels(image).print();\n * ```\n *\n * @param pixels The input image to construct the tensor from. The\n * supported image types are all 4-channel. You can also pass in an image\n * object with following attributes:\n * `{data: Uint8Array; width: number; height: number}`\n * @param numChannels The number of channels of the output tensor. A\n * numChannels value less than 4 allows you to ignore channels. Defaults to\n * 3 (ignores alpha channel of input image).\n */\n/** @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true} */\nfunction fromPixels_(\n pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|\n HTMLVideoElement,\n numChannels = 3): Tensor3D {\n // Sanity checks.\n if (numChannels > 4) {\n throw new Error(\n 'Cannot construct Tensor with more than 4 channels from pixels.');\n }\n if (pixels == null) {\n throw new Error('pixels passed to tf.browser.fromPixels() can not be null');\n }\n let isPixelData = false;\n let isImageData = false;\n let isVideo = false;\n let isImage = false;\n let isCanvasLike = false;\n if ((pixels as PixelData).data instanceof Uint8Array) {\n isPixelData = true;\n } else if (\n typeof (ImageData) !== 'undefined' && pixels instanceof ImageData) {\n isImageData = true;\n } else if (\n typeof (HTMLVideoElement) !== 'undefined' &&\n pixels instanceof HTMLVideoElement) {\n isVideo = true;\n } else if (\n typeof (HTMLImageElement) !== 'undefined' &&\n pixels instanceof HTMLImageElement) {\n isImage = true;\n // tslint:disable-next-line: no-any\n } else if ((pixels as any).getContext != null) {\n isCanvasLike = true;\n } else {\n throw new Error(\n 'pixels passed to tf.browser.fromPixels() must be either an ' +\n `HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData ` +\n `in browser, or OffscreenCanvas, ImageData in webworker` +\n ` or {data: Uint32Array, width: number, height: number}, ` +\n `but was ${(pixels as {}).constructor.name}`);\n }\n if (isVideo) {\n const HAVE_CURRENT_DATA_READY_STATE = 2;\n if (isVideo &&\n (pixels as HTMLVideoElement).readyState <\n HAVE_CURRENT_DATA_READY_STATE) {\n throw new Error(\n 'The video element has not loaded data yet. Please wait for ' +\n '`loadeddata` event on the