// Pad shader template.
//
// Each invocation handles the output element at `global_idx`: it writes
// either the constant pad value or an element read from `data`, depending
// on where the output coordinate falls relative to the padded region.
//
// Template parameters:
//   dim_value_zero - when set, every output element is the constant value
//                    and no input is read.
//   is_float16     - selects how uniforms.constant_value is reinterpreted.
//   pad_mode       - one of the PAD_MODE_* values defined below.

#define PAD_MODE_CONSTANT 0
#define PAD_MODE_EDGE 1
#define PAD_MODE_REFLECT 2
#define PAD_MODE_WRAP 3

#use guardAgainstOutOfBoundsWorkgroupSizes
#use getElementAt
#use .offsetToIndices .setByOffset .rank

#param dim_value_zero
#param is_float16
#param pad_mode

$MAIN {
  guardAgainstOutOfBoundsWorkgroupSizes(uniforms.output_size);

  // uniforms.constant_value carries the pad value as raw bits and is
  // reinterpreted here. WGSL bitcast requires an explicit target type.
  // For f16 the bits are viewed as vec2<f16> and lane 0 is used.
  // NOTE(review): assumes the uniform is a 32-bit scalar and that the
  // `output_value_t` alias is emitted for `output` by the host side - confirm.
  let constant_value =
#if is_float16
      bitcast<vec2<f16>>(uniforms.constant_value)[0];
#else
      bitcast<output_value_t>(uniforms.constant_value);
#endif

#if dim_value_zero
  output[global_idx] = constant_value;
#else
  let output_indices = output.offsetToIndices(global_idx);
  var input_index = u32(0);
  var use_pad_value = false;
  var in_coord = i32(0);

  // Walk the dimensions, mapping each output coordinate to an input
  // coordinate and accumulating the flat input offset. The loop stops early
  // once constant mode has decided this element is padding.
  for (var dim = 0; dim < output.rank && !use_pad_value; dim++) {
    let output_index = i32(getElementAt(output_indices, dim, output.rank));
    let lower_pads = getElementAt(uniforms.lower_pads, dim, output.rank);
    let data_shape = i32(getElementAt(uniforms.data_shape, dim, output.rank));

    // Every mode opens an `if` chain for coordinates outside
    // [lower_pads, lower_pads + data_shape); the shared `else` after the
    // #endif handles coordinates inside the data region.
#if pad_mode == PAD_MODE_CONSTANT
    if (output_index < lower_pads || output_index >= data_shape + lower_pads) {
      use_pad_value = true;
#elif pad_mode == PAD_MODE_EDGE
    if (output_index < lower_pads) {
      in_coord = 0;
    } else if (output_index >= data_shape + lower_pads) {
      in_coord = data_shape - 1;
#elif pad_mode == PAD_MODE_REFLECT
    if (output_index < lower_pads || output_index >= data_shape + lower_pads) {
      // Mirror about index 0, fold into one period of length
      // 2 * (data_shape - 1), then mirror about the last index.
      in_coord = output_index - lower_pads;
      if (in_coord < 0) {
        in_coord = -in_coord;
      }
      let _2n_1 = 2 * (data_shape - 1);
      in_coord = in_coord % _2n_1;
      if (in_coord >= data_shape) {
        in_coord = _2n_1 - in_coord;
      }
#else // PAD_MODE_WRAP
    if (output_index < lower_pads) {
      in_coord = data_shape + output_index - lower_pads;
    } else if (output_index >= data_shape + lower_pads) {
      in_coord = output_index - data_shape - lower_pads;
#endif // pad_mode
    } else {
      in_coord = output_index - lower_pads;
    }

    // The last dimension contributes in_coord directly; earlier dimensions
    // are scaled by their entry in uniforms.data_stride. The multiply is only
    // emitted when output.rank > 1.
    input_index += select(u32(in_coord)
#if output.rank > 1
                              * getElementAt(uniforms.data_stride, dim, output.rank - 1)
#endif
                              ,
                          u32(in_coord), dim == output.rank - 1);
  }

  output.setByOffset(global_idx, select(data[input_index], constant_value, use_pad_value));
#endif
} // MAIN