"""Zip primitive used by the GEMM function.

Takes 1 to 4 rows of data and interleaves them in 8 byte chunks. Pads to a
multiple of 8 length with zeros. Calculates row sums and appends those at the
end.
"""
|
|
|
|
import neon_emitter
|
|
|
|
|
|
class Error(Exception):
  """Base class for all errors raised by this module."""
|
|
|
|
|
|
class ConfigurationError(Error):
  """Raised when asked to generate an unsupported zip configuration."""
|
|
|
|
|
|
class ZipLane(object):
  """Per-row state used while zipping one input lane.

  Holds:
    input_address: register with the lane's current read address.
    load: double register the lane's data is loaded into.
    aggregator: quad register accumulating the lane's running sum.
  """

  def __init__(self, input_address, load, aggregator):
    (self.input_address, self.load, self.aggregator) = (input_address, load,
                                                        aggregator)
|
|
|
|
|
|
def GenerateZipLanes(emitter, registers, zip_lanes, input_address, stride):
  """Prepares read lanes for the zip operation.

  Lane 0 reads straight from the given input address; each later lane gets a
  fresh general register initialized to the previous lane's address plus the
  stride.

  Args:
    emitter: ARM/NEON emitter.
    registers: ARM/NEON registers state.
    zip_lanes: number of lanes to prepare.
    input_address: register that contains the input address for the first lane.
    stride: memory stride for lane inputs.

  Returns:
    Array of ZipLane objects.
  """
  lanes = []
  previous_address = input_address
  for lane_index in range(zip_lanes):
    if lane_index == 0:
      address = input_address
    else:
      address = registers.GeneralRegister()
    lanes.append(ZipLane(address, registers.DoubleRegister(),
                         registers.QuadRegister(2)))
    if lane_index != 0:
      # address = previous lane's address + stride.
      emitter.EmitAdd(address, previous_address, stride)
      previous_address = address
  return lanes
|
|
|
|
|
|
def BuildName(zip_lanes, leftovers, aligned):
  """Compose the generated zip function's name from its configuration."""
  parts = ['zip_%dx8' % zip_lanes]
  if leftovers:
    parts.append('_%d' % leftovers)
  if aligned:
    parts.append('_aligned')
  return ''.join(parts)
|
|
|
|
|
|
def GenerateClearAggregators(emitter, lanes):
  """Emit code zeroing every lane's running-sum aggregator register."""
  for lane in lanes:
    emitter.EmitVMov('i16', lane.aggregator, emitter.ImmediateConstant(0))
|
|
|
|
|
|
def GenerateLoadAggregateStore(emitter, lanes, output_address, alignment):
  """Emit inner loop code for reading N lanes and interweaving them.

  Loads 8 bytes from each lane (post-incrementing its address), widen-adds
  each chunk into its lane aggregator, then stores all loaded registers
  back-to-back at the output address.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Load Aggregate Store.')

  # One 8-byte chunk per lane, addresses post-incremented.
  for lane in lanes:
    emitter.EmitVLoad(
        '1.8', lane.load,
        emitter.DereferenceIncrement(lane.input_address, alignment))

  # Accumulate each chunk into its lane's sum (u8 widened to u16).
  for lane in lanes:
    emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load)

  # Store the loaded chunks interleaved, in lane order.
  emitter.EmitVStoreA('1.8', [lane.load for lane in lanes],
                      emitter.DereferenceIncrement(output_address, 64))
|
|
|
|
|
|
def GenerateLeftoverLoadAggregateStore(emitter, leftovers, lanes,
                                       output_address):
  """Handle leftovers when the count is not a multiple of 8.

  Reads the remaining 1-7 bytes of each lane using the widest possible
  partial loads, zero-pads the rest of the load register, aggregates the
  values and stores one final interleaved 8-byte chunk per lane.

  Args:
    emitter: ARM/NEON emitter.
    leftovers: number of remaining bytes per lane, 1 to 7 inclusive.
    lanes: ZipLane objects describing the lanes being zipped.
    output_address: register with the destination address.

  Raises:
    ConfigurationError: if leftovers is outside the 1-7 range.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Leftover Load Aggregate Store.')

  # Load plan per leftover count: an ordered list of partial loads, each a
  # (bit width, destination lane index at that width, post-increment?) tuple.
  # The widths sum to exactly `leftovers` bytes; the lane index places each
  # chunk right after the previous one (index is in units of the load width,
  # e.g. 8-bit lane 4 is byte 4, 16-bit lane 2 is bytes 4-5).
  load_plans = {
      1: [(8, 0, False)],
      2: [(16, 0, False)],
      3: [(16, 0, True), (8, 2, False)],
      4: [(32, 0, False)],
      5: [(32, 0, True), (8, 4, False)],
      6: [(32, 0, True), (16, 2, False)],
      7: [(32, 0, True), (16, 2, True), (8, 6, False)],
  }
  # Validate before emitting any code (the original cleared registers first,
  # leaving a half-emitted function on error).
  if leftovers not in load_plans:
    raise ConfigurationError('Unsupported leftover num: %d' % leftovers)

  # Clear load registers so the unread tail bytes remain zero padding.
  for lane in lanes:
    emitter.EmitVMov('i8', lane.load, emitter.ImmediateConstant(0))

  # Emit each partial load for all lanes before moving to the next step,
  # preserving the all-lanes-per-step emission order.
  for width, lane_index, increment in load_plans[leftovers]:
    for lane in lanes:
      if increment:
        address = emitter.DereferenceIncrement(lane.input_address, None)
      else:
        address = emitter.Dereference(lane.input_address, None)
      emitter.EmitVLoad('1.%d' % width, emitter.Lane(lane.load, lane_index),
                        address)

  # Aggregate.
  for lane in lanes:
    emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load)

  # Store.
  emitter.EmitVStoreA('1.8', [lane.load for lane in lanes],
                      emitter.DereferenceIncrement(output_address, 64))
|
|
|
|
|
|
def GenerateAggregatorReduction(emitter, registers, lanes, output_address,
                                multiplicative_offset, additive_offset):
  """Reduce 4 lane sum aggregators to 1 value and store the sums."""
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')

  # sums = (reduced aggregators) * multiplicative_offset + additive_offset.
  multiplier = registers.DoubleRegister()
  emitter.EmitVMov('32', emitter.Lane(multiplier, 0), multiplicative_offset)
  offset = registers.QuadRegister()
  emitter.EmitVDup('32', offset, additive_offset)

  # Pairwise widen the 16 bit partial sums to 32 bit within each aggregator.
  for lane in lanes:
    emitter.EmitVPaddl('u16', lane.aggregator, lane.aggregator)

  # Fold each aggregator quad down to one double of two 32 bit partials.
  partials = []
  for lane in lanes:
    partial = registers.DoubleRegister()
    partials.append(partial)
    emitter.EmitVPadd('u32', partial, registers.Low(lane.aggregator),
                      registers.High(lane.aggregator))

  sums = registers.QuadRegister()
  sums_low = registers.Low(sums)
  sums_high = registers.High(sums)

  # Pack one final 32 bit sum per lane into the low/high halves of `sums`.
  lane_count = len(lanes)
  if lane_count == 1:
    emitter.EmitVPadd('u32', sums_low, partials[0], partials[0])
  elif lane_count == 2:
    emitter.EmitVPadd('u32', sums_low, partials[0], partials[1])
  elif lane_count == 3:
    emitter.EmitVPadd('u32', sums_low, partials[0], partials[1])
    emitter.EmitVPadd('u32', sums_high, partials[2], partials[2])
  elif lane_count == 4:
    emitter.EmitVPadd('u32', sums_low, partials[0], partials[1])
    emitter.EmitVPadd('u32', sums_high, partials[2], partials[3])
  else:
    raise ConfigurationError('Unexpected number of aggregators to reduce: %d' %
                             lane_count)

  emitter.EmitVMul('i32', sums, sums, emitter.Lane(multiplier, 0))
  emitter.EmitVAdd('i32', sums, sums, offset)

  # Store exactly lane_count 32 bit sums at the output address.
  if lane_count == 1:
    emitter.EmitVStore('1.32', emitter.Lane(sums_low, 0),
                       emitter.Dereference(output_address, None))
  elif lane_count == 2:
    emitter.EmitVStore('1.32', sums_low,
                       emitter.Dereference(output_address, 64))
  elif lane_count == 3:
    emitter.EmitVStore('1.32', sums_low,
                       emitter.DereferenceIncrement(output_address, 64))
    emitter.EmitVStore('1.32', emitter.Lane(sums_high, 0),
                       emitter.Dereference(output_address, None))
  elif lane_count == 4:
    emitter.EmitVStoreA('1.32', [sums_low, sums_high],
                        emitter.DereferenceIncrement(output_address, 64))
|
|
|
|
|
|
def GenerateZipNx8(emitter, zip_lanes, leftovers, aligned):
  """Emit the zip function for a given number of rows and row size leftovers.

  Args:
    emitter: ARM/NEON emitter.
    zip_lanes: number of rows interleaved per call, 1 to 4 inclusive.
    leftovers: count % 8, trailing bytes per row, 0 to 7 inclusive.
    aligned: when True the generated function asserts the source pointer is
      8-byte aligned and uses aligned loads in the main loop.

  Raises:
    ConfigurationError: when zip_lanes or leftovers is out of range.
  """
  if leftovers < 0 or leftovers > 7:
    raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
  if zip_lanes < 1 or zip_lanes > 4:
    # Fixed duplicated word in the original message ('should should').
    raise ConfigurationError('Zip_lanes should be 1, 2, 3 or 4.')

  name = BuildName(zip_lanes, leftovers, aligned)

  emitter.EmitFunctionBeginA(
      name, [['const std::uint8_t*', 'source'], ['std::int32_t', 'count'],
             ['std::int32_t', 'stride'], ['std::uint8_t*', 'destination'],
             ['std::int32_t', 'multiplicative_offset'],
             ['std::int32_t', 'additive_offset']], 'void')
  # Encode the generated function's preconditions as runtime asserts.
  emitter.EmitAssert('count %% 8 == %d' % leftovers)
  emitter.EmitAssert('count <= 2048')
  emitter.EmitAssert('count >= 8')
  emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
  if aligned:
    emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
  if zip_lanes > 1:
    emitter.EmitAssert('stride % 8 == 0')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')
  output_address = registers.MapParameter('destination')

  lanes = GenerateZipLanes(emitter, registers, zip_lanes,
                           registers.MapParameter('source'),
                           registers.MapParameter('stride'))

  # Reserve the leftover bytes so the main loop handles whole chunks only.
  if leftovers:
    emitter.EmitSub(count, count, emitter.ImmediateConstant(leftovers))

  GenerateClearAggregators(emitter, lanes)

  # Main loop: consume 8 bytes per lane per iteration until count hits 0.
  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  GenerateLoadAggregateStore(emitter, lanes, output_address, 64 if aligned else
                             None)

  emitter.EmitNewline()
  emitter.EmitBneBack(1)

  if leftovers:
    GenerateLeftoverLoadAggregateStore(emitter, leftovers, lanes,
                                       output_address)

  # Reduce the accumulated sums, scale/offset them and append to the output.
  GenerateAggregatorReduction(emitter, registers, lanes, output_address,
                              registers.MapParameter('multiplicative_offset'),
                              registers.MapParameter('additive_offset'))

  emitter.EmitAsmEnd(registers.MappedParameters(), [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()
|
|
|
|
|
|
def GenerateFunctions(emitter):
  """Emit every zip variant: aligned/unaligned x 1-4 lanes x 0-7 leftovers."""
  for aligned in (True, False):
    for zip_lanes in range(1, 5):
      for leftovers in range(8):
        GenerateZipNx8(emitter, zip_lanes, leftovers, aligned)
        emitter.EmitNewline()