# Copyright 2016 The Gemmlowp Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""32bit ARM/NEON assembly emitter.
|
|
|
|
Used by code generators to produce ARM assembly with NEON simd code.
|
|
Provides tools for easier register management: named register variable
|
|
allocation/deallocation, and offers a more procedural/structured approach
|
|
to generating assembly.
|
|
|
|
TODO: right now neon emitter prints out assembly instructions immediately,
|
|
it might be beneficial to keep the whole structure and emit the assembly after
|
|
applying some optimizations like: instruction reordering or register reuse.
|
|
|
|
TODO: NeonRegister object assigns explicit registers at allocation time.
|
|
Similarily to emiting code, register mapping and reuse can be performed and
|
|
optimized lazily.
|
|
"""


class Error(Exception):
  """Module level error."""


class RegisterAllocationError(Error):
  """Cannot allocate registers."""


class RegisterDeallocationError(Error):
  """Cannot deallocate registers."""


class LaneError(Error):
  """Wrong lane number."""


class ArgumentError(Error):
  """Wrong argument."""


def _Low(register):
  """Returns the low double register of a quad register, e.g. q1 -> d2."""
  assert register[0] == 'q'
  num = int(register[1:])
  return 'd%d' % (num * 2)


def _High(register):
  """Returns the high double register of a quad register, e.g. q1 -> d3."""
  assert register[0] == 'q'
  num = int(register[1:])
  return 'd%d' % (num * 2 + 1)


def _ExpandQuads(registers):
  """Expands quad registers into their component double registers."""
  doubles = []
  for register in registers:
    if register[0] == 'q':
      doubles.append(_Low(register))
      doubles.append(_High(register))
    else:
      doubles.append(register)
  return doubles


def _MakeCompatible(op1, op2, op3):
  """Narrows quad operands to their low halves if any operand is a double."""
  if op1[0] == 'd' or op2[0] == 'd' or op3[0] == 'd':
    if op1[0] == 'q':
      op1 = _Low(op1)
    if op2[0] == 'q':
      op2 = _Low(op2)
    if op3[0] == 'q':
      op3 = _Low(op3)
  return (op1, op2, op3)
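

# For example (illustrative): _MakeCompatible('q0', 'd2', 'q1') narrows the
# quad operands to their low halves and returns ('d0', 'd2', 'd2').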


class _NeonRegisters32Bit(object):
  """Utility that keeps track of used 32bit ARM/NEON registers."""

  def __init__(self):
    self.double = set()
    self.double_ever = set()
    self.general = set()
    self.general_ever = set()
    self.parameters = dict()
    self.output_parameters = dict()

  def MapParameter(self, parameter, parameter_value=None):
    if not parameter_value:
      parameter_value = parameter
    self.parameters[parameter] = (parameter_value, 'r')
    return '%%[%s]' % parameter

  def MapMemoryParameter(self, parameter, parameter_value=None):
    if not parameter_value:
      parameter_value = parameter
    self.parameters[parameter] = (parameter_value, 'm')
    return '%%[%s]' % parameter

  def MapOutputParameter(self, parameter, parameter_value=None):
    if not parameter_value:
      parameter_value = parameter
    self.output_parameters[parameter] = (parameter_value, '+r')
    return '%%[%s]' % parameter

  def DoubleRegister(self, min_val=0):
    for i in range(min_val, 32):
      if i not in self.double:
        self.double.add(i)
        self.double_ever.add(i)
        return 'd%d' % i
    raise RegisterAllocationError('Not enough double registers.')

  def QuadRegister(self, min_val=0):
    for i in range(min_val, 16):
      if ((i * 2) not in self.double) and ((i * 2 + 1) not in self.double):
        self.double.add(i * 2)
        self.double.add(i * 2 + 1)
        self.double_ever.add(i * 2)
        self.double_ever.add(i * 2 + 1)
        return 'q%d' % i
    raise RegisterAllocationError('Not enough quad registers.')

  def GeneralRegister(self):
    for i in range(0, 16):
      if i not in self.general:
        self.general.add(i)
        self.general_ever.add(i)
        return 'r%d' % i
    raise RegisterAllocationError('Not enough general registers.')

  def MappedParameters(self):
    return list(self.parameters.items())

  def MappedOutputParameters(self):
    return list(self.output_parameters.items())

  def Clobbers(self):
    return (['r%d' % i for i in self.general_ever] +
            ['d%d' % i for i in self.DoubleClobbers()])

  def DoubleClobbers(self):
    return sorted(self.double_ever)

  def FreeRegister(self, register):
    assert len(register) > 1
    if register[0] not in ['r', 'd', 'q']:
      return

    num = int(register[1:])

    if register[0] == 'r':
      assert num in self.general
      self.general.remove(num)
    elif register[0] == 'd':
      assert num in self.double
      self.double.remove(num)
    elif register[0] == 'q':
      assert num * 2 in self.double
      assert num * 2 + 1 in self.double
      self.double.remove(num * 2)
      self.double.remove(num * 2 + 1)
    else:
      raise RegisterDeallocationError('Register not allocated: %s' % register)

  def FreeRegisters(self, registers):
    for register in registers:
      self.FreeRegister(register)
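
  # Illustrative lifecycle sketch: registers stay reserved from allocation
  # until freed, and a quad register reserves both of its double halves, e.g.:
  #
  #   registers = _NeonRegisters32Bit()
  #   quad = registers.QuadRegister()   # reserves d0 and d1
  #   registers.FreeRegisters([quad])   # releases both halves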


class NeonEmitter(object):
  """Emits ARM/NEON assembly opcodes."""

  def __init__(self, debug=False):
    self.ops = {}
    self.indent = ''
    self.debug = debug

  def PushIndent(self, delta='  '):
    self.indent += delta

  def PopIndent(self, delta=2):
    self.indent = self.indent[:-delta]

  def EmitIndented(self, what):
    print(self.indent + what)

  def PushOp(self, op):
    if op in self.ops:
      self.ops[op] += 1
    else:
      self.ops[op] = 1

  def ClearCounters(self):
    self.ops.clear()

  def EmitNewline(self):
    print('')

  def EmitPreprocessor1(self, op, param):
    print('#%s %s' % (op, param))

  def EmitPreprocessor(self, op):
    print('#%s' % op)

  def EmitInclude(self, include):
    self.EmitPreprocessor1('include', include)

  def EmitCall1(self, function, param):
    self.EmitIndented('%s(%s);' % (function, param))

  def EmitAssert(self, assert_expression):
    if self.debug:
      self.EmitCall1('assert', assert_expression)

  def EmitHeaderBegin(self, header_name, includes):
    self.EmitPreprocessor1('ifndef', (header_name + '_H_').upper())
    self.EmitPreprocessor1('define', (header_name + '_H_').upper())
    self.EmitNewline()
    if includes:
      for include in includes:
        self.EmitInclude(include)
      self.EmitNewline()

  def EmitHeaderEnd(self):
    self.EmitPreprocessor('endif')

  def EmitCode(self, code):
    self.EmitIndented('%s;' % code)

  def EmitFunctionBeginA(self, function_name, params, return_type):
    self.EmitIndented('%s %s(%s) {' %
                      (return_type, function_name,
                       ', '.join(['%s %s' % (t, n) for (t, n) in params])))
    self.PushIndent()

  def EmitFunctionEnd(self):
    self.PopIndent()
    self.EmitIndented('}')

  def EmitAsmBegin(self):
    self.EmitIndented('asm volatile(')
    self.PushIndent()

  def EmitAsmMapping(self, elements):
    if elements:
      self.EmitIndented(': ' + ', '.join(
          ['[%s] "%s"(%s)' % (d, v[1], v[0]) for (d, v) in elements]))
    else:
      self.EmitIndented(':')

  def EmitClobbers(self, elements):
    if elements:
      self.EmitIndented(': ' + ', '.join(['"%s"' % c for c in elements]))
    else:
      self.EmitIndented(':')

  def EmitAsmEnd(self, registers):
    self.EmitAsmMapping(registers.MappedOutputParameters())
    self.EmitAsmMapping(registers.MappedParameters())
    self.EmitClobbers(registers.Clobbers() + ['cc', 'memory'])
    self.PopIndent()
    self.EmitIndented(');')

  def EmitComment(self, comment):
    self.EmitIndented('// ' + comment)

  def EmitNumericalLabel(self, label):
    self.EmitIndented('"%d:"' % label)

  def EmitOp1(self, op, param1):
    self.PushOp(op)
    self.EmitIndented('"%s %s\\n"' % (op, param1))

  def EmitOp2(self, op, param1, param2):
    self.PushOp(op)
    self.EmitIndented('"%s %s, %s\\n"' % (op, param1, param2))

  def EmitOp3(self, op, param1, param2, param3):
    self.PushOp(op)
    self.EmitIndented('"%s %s, %s, %s\\n"' % (op, param1, param2, param3))

  def EmitAdd(self, destination, source, param):
    self.EmitOp3('add', destination, source, param)

  def EmitSubs(self, destination, source, param):
    self.EmitOp3('subs', destination, source, param)

  def EmitSub(self, destination, source, param):
    self.EmitOp3('sub', destination, source, param)

  def EmitMul(self, destination, source, param):
    self.EmitOp3('mul', destination, source, param)

  def EmitMov(self, param1, param2):
    self.EmitOp2('mov', param1, param2)

  def EmitBeqBack(self, label):
    self.EmitOp1('beq', '%db' % label)

  def EmitBeqFront(self, label):
    self.EmitOp1('beq', '%df' % label)

  def EmitBgtBack(self, label):
    self.EmitOp1('bgt', '%db' % label)

  def EmitBgtFront(self, label):
    self.EmitOp1('bgt', '%df' % label)

  def EmitBleBack(self, label):
    self.EmitOp1('ble', '%db' % label)

  def EmitBleFront(self, label):
    self.EmitOp1('ble', '%df' % label)

  def EmitBneBack(self, label):
    self.EmitOp1('bne', '%db' % label)

  def EmitBneFront(self, label):
    self.EmitOp1('bne', '%df' % label)

  def EmitVAdd(self, add_type, destination, source_1, source_2):
    destination, source_1, source_2 = _MakeCompatible(destination, source_1,
                                                      source_2)
    self.EmitOp3('vadd.%s' % add_type, destination, source_1, source_2)

  def EmitVAddw(self, add_type, destination, source_1, source_2):
    self.EmitOp3('vaddw.%s' % add_type, destination, source_1, source_2)

  def EmitVSub(self, sub_type, destination, source_1, source_2):
    destination, source_1, source_2 = _MakeCompatible(destination, source_1,
                                                      source_2)
    self.EmitOp3('vsub.%s' % sub_type, destination, source_1, source_2)

  def EmitVCvt(self, cvt_to, cvt_from, destination, source):
    self.EmitOp2('vcvt.%s.%s' % (cvt_to, cvt_from), destination, source)

  def EmitVDup(self, dup_type, destination, source):
    self.EmitOp2('vdup.%s' % dup_type, destination, source)

  def EmitVMax(self, size, destination, source_1, source_2):
    self.EmitOp3('vmax.%s' % size, destination, source_1, source_2)

  def EmitVMin(self, size, destination, source_1, source_2):
    self.EmitOp3('vmin.%s' % size, destination, source_1, source_2)

  def EmitVMov(self, mov_type, destination, source):
    self.EmitOp2('vmov.%s' % mov_type, destination, source)

  def EmitVMovl(self, mov_type, destination, source):
    if source[0] == 'q':
      source = _Low(source)
    self.EmitOp2('vmovl.%s' % mov_type, destination, source)

  def EmitVMovl2(self, mov_type, destination_1, destination_2, source):
    self.EmitVMovl(mov_type, destination_2, _High(source))
    self.EmitVMovl(mov_type, destination_1, _Low(source))

  def EmitVQmovn(self, mov_type, destination, source):
    if destination[0] == 'q':
      destination = _Low(destination)
    self.EmitOp2('vqmovn.%s' % mov_type, destination, source)

  def EmitVQmovn2(self, mov_type, destination, source_1, source_2):
    self.EmitVQmovn(mov_type, _Low(destination), source_1)
    self.EmitVQmovn(mov_type, _High(destination), source_2)

  def EmitVQmovun(self, mov_type, destination, source):
    if destination[0] == 'q':
      destination = _Low(destination)
    self.EmitOp2('vqmovun.%s' % mov_type, destination, source)

  def EmitVQmovun2(self, mov_type, destination, source_1, source_2):
    self.EmitVQmovun(mov_type, _Low(destination), source_1)
    self.EmitVQmovun(mov_type, _High(destination), source_2)

  def EmitVMul(self, mul_type, destination, source_1, source_2):
    destination, source_1, source_2 = _MakeCompatible(destination, source_1,
                                                      source_2)
    self.EmitOp3('vmul.%s' % mul_type, destination, source_1, source_2)

  def EmitVMulScalar(self, mul_type, destination, source_1, source_2):
    self.EmitOp3('vmul.%s' % mul_type, destination, source_1, source_2)

  def EmitVMull(self, mul_type, destination, source_1, source_2):
    self.EmitOp3('vmull.%s' % mul_type, destination, source_1, source_2)

  def EmitVPadd(self, add_type, destination, source_1, source_2):
    self.EmitOp3('vpadd.%s' % add_type, destination, source_1, source_2)

  def EmitVPaddl(self, add_type, destination, source):
    self.EmitOp2('vpaddl.%s' % add_type, destination, source)

  def EmitVPadal(self, add_type, destination, source):
    self.EmitOp2('vpadal.%s' % add_type, destination, source)

  def EmitLdr(self, register, value):
    self.EmitOp2('ldr', register, value)

  def EmitVLoad(self, load_no, load_type, destination, source):
    self.EmitVLoadA(load_no, load_type, [destination], source)

  def EmitVLoadA(self, load_no, load_type, destinations, source):
    self.EmitOp2('vld%d.%d' % (load_no, load_type),
                 '{%s}' % ', '.join(_ExpandQuads(destinations)), source)

  def EmitVLoadAE(self,
                  load_type,
                  elem_count,
                  destinations,
                  source,
                  alignment=None):
    """Emits a sequence of loads for elem_count elements of load_type bits."""
    bits_to_load = load_type * elem_count
    destinations = _ExpandQuads(destinations)
    if len(destinations) * 64 < bits_to_load:
      raise ArgumentError('Too few destinations: %d to load %d bits.' %
                          (len(destinations), bits_to_load))

    while bits_to_load > 0:
      if bits_to_load >= 256:
        self.EmitVLoadA(1, 32, destinations[:4],
                        self.DereferenceIncrement(source, alignment))
        bits_to_load -= 256
        destinations = destinations[4:]
      elif bits_to_load >= 192:
        self.EmitVLoadA(1, 32, destinations[:3],
                        self.DereferenceIncrement(source, alignment))
        bits_to_load -= 192
        destinations = destinations[3:]
      elif bits_to_load >= 128:
        self.EmitVLoadA(1, 32, destinations[:2],
                        self.DereferenceIncrement(source, alignment))
        bits_to_load -= 128
        destinations = destinations[2:]
      elif bits_to_load >= 64:
        self.EmitVLoad(1, 32, destinations[0],
                       self.DereferenceIncrement(source, alignment))
        bits_to_load -= 64
        destinations = destinations[1:]
      else:
        destination = destinations[0]
        if bits_to_load == 56:
          self.EmitVLoad(1, 32,
                         self.Lane(32, destination, 0),
                         self.DereferenceIncrement(source))
          self.EmitVLoad(1, 16,
                         self.Lane(16, destination, 2),
                         self.DereferenceIncrement(source))
          self.EmitVLoad(1, 8,
                         self.Lane(8, destination, 6),
                         self.DereferenceIncrement(source))
        elif bits_to_load == 48:
          self.EmitVLoad(1, 32,
                         self.Lane(32, destination, 0),
                         self.DereferenceIncrement(source))
          self.EmitVLoad(1, 16,
                         self.Lane(16, destination, 2),
                         self.DereferenceIncrement(source))
        elif bits_to_load == 40:
          self.EmitVLoad(1, 32,
                         self.Lane(32, destination, 0),
                         self.DereferenceIncrement(source))
          self.EmitVLoad(1, 8,
                         self.Lane(8, destination, 4),
                         self.DereferenceIncrement(source))
        elif bits_to_load == 32:
          self.EmitVLoad(1, 32,
                         self.Lane(32, destination, 0),
                         self.DereferenceIncrement(source))
        elif bits_to_load == 24:
          self.EmitVLoad(1, 16,
                         self.Lane(16, destination, 0),
                         self.DereferenceIncrement(source))
          self.EmitVLoad(1, 8,
                         self.Lane(8, destination, 2),
                         self.DereferenceIncrement(source))
        elif bits_to_load == 16:
          self.EmitVLoad(1, 16,
                         self.Lane(16, destination, 0),
                         self.DereferenceIncrement(source))
        elif bits_to_load == 8:
          self.EmitVLoad(1, 8,
                         self.Lane(8, destination, 0),
                         self.DereferenceIncrement(source))
        else:
          raise ArgumentError('Wrong leftover: %d' % bits_to_load)
        return
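
  # Tail handling example (illustrative): loading 7 8-bit elements (56 bits)
  # into a single d register issues three lane loads: a vld1.32 into 32-bit
  # lane 0, a vld1.16 into 16-bit lane 2, and a vld1.8 into 8-bit lane 6.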

  def EmitVLoadE(self, load_type, count, destination, source, alignment=None):
    self.EmitVLoadAE(load_type, count, [destination], source, alignment)

  def EmitVLoadAllLanes(self, load_no, load_type, destination, source):
    destinations = []
    if destination[0] == 'q':
      destinations.append(self.AllLanes(_Low(destination)))
      destinations.append(self.AllLanes(_High(destination)))
    else:
      destinations.append(self.AllLanes(destination))
    self.EmitVLoadA(load_no, load_type, destinations, source)

  def EmitVLoadOffset(self, load_no, load_type, destination, source, offset):
    self.EmitVLoadOffsetA(load_no, load_type, [destination], source, offset)

  def EmitVLoadOffsetA(self, load_no, load_type, destinations, source, offset):
    assert len(destinations) <= 4
    self.EmitOp3('vld%d.%d' % (load_no, load_type),
                 '{%s}' % ', '.join(_ExpandQuads(destinations)), source,
                 offset)

  def EmitPld(self, load_address_register):
    self.EmitOp1('pld', '[%s]' % load_address_register)

  def EmitPldw(self, store_address_register):
    self.EmitOp1('pldw', '[%s]' % store_address_register)

  def EmitPldOffset(self, load_address_register, offset):
    self.EmitOp1('pld', '[%s, %s]' % (load_address_register, offset))

  def EmitPldwOffset(self, store_address_register, offset):
    self.EmitOp1('pldw', '[%s, %s]' % (store_address_register, offset))

  def EmitVShl(self, shift_type, destination, source, shift):
    self.EmitOp3('vshl.%s' % shift_type, destination, source, shift)

  def EmitVStore(self, store_no, store_type, source, destination):
    self.EmitVStoreA(store_no, store_type, [source], destination)

  def EmitVStoreA(self, store_no, store_type, sources, destination):
    self.EmitOp2('vst%d.%d' % (store_no, store_type),
                 '{%s}' % ', '.join(_ExpandQuads(sources)), destination)

  def EmitVStoreAE(self,
                   store_type,
                   elem_count,
                   sources,
                   destination,
                   alignment=None):
    """Emits a sequence of stores for elem_count elements of store_type bits."""
    bits_to_store = store_type * elem_count
    sources = _ExpandQuads(sources)
    if len(sources) * 64 < bits_to_store:
      raise ArgumentError('Too few sources: %d to store %d bits.' %
                          (len(sources), bits_to_store))

    while bits_to_store > 0:
      if bits_to_store >= 256:
        self.EmitVStoreA(1, 32, sources[:4],
                         self.DereferenceIncrement(destination, alignment))
        bits_to_store -= 256
        sources = sources[4:]
      elif bits_to_store >= 192:
        self.EmitVStoreA(1, 32, sources[:3],
                         self.DereferenceIncrement(destination, alignment))
        bits_to_store -= 192
        sources = sources[3:]
      elif bits_to_store >= 128:
        self.EmitVStoreA(1, 32, sources[:2],
                         self.DereferenceIncrement(destination, alignment))
        bits_to_store -= 128
        sources = sources[2:]
      elif bits_to_store >= 64:
        self.EmitVStore(1, 32, sources[0],
                        self.DereferenceIncrement(destination, alignment))
        bits_to_store -= 64
        sources = sources[1:]
      else:
        source = sources[0]
        if bits_to_store == 56:
          self.EmitVStore(1, 32,
                          self.Lane(32, source, 0),
                          self.DereferenceIncrement(destination))
          self.EmitVStore(1, 16,
                          self.Lane(16, source, 2),
                          self.DereferenceIncrement(destination))
          self.EmitVStore(1, 8,
                          self.Lane(8, source, 6),
                          self.DereferenceIncrement(destination))
        elif bits_to_store == 48:
          self.EmitVStore(1, 32,
                          self.Lane(32, source, 0),
                          self.DereferenceIncrement(destination))
          self.EmitVStore(1, 16,
                          self.Lane(16, source, 2),
                          self.DereferenceIncrement(destination))
        elif bits_to_store == 40:
          self.EmitVStore(1, 32,
                          self.Lane(32, source, 0),
                          self.DereferenceIncrement(destination))
          self.EmitVStore(1, 8,
                          self.Lane(8, source, 4),
                          self.DereferenceIncrement(destination))
        elif bits_to_store == 32:
          self.EmitVStore(1, 32,
                          self.Lane(32, source, 0),
                          self.DereferenceIncrement(destination))
        elif bits_to_store == 24:
          self.EmitVStore(1, 16,
                          self.Lane(16, source, 0),
                          self.DereferenceIncrement(destination))
          self.EmitVStore(1, 8,
                          self.Lane(8, source, 2),
                          self.DereferenceIncrement(destination))
        elif bits_to_store == 16:
          self.EmitVStore(1, 16,
                          self.Lane(16, source, 0),
                          self.DereferenceIncrement(destination))
        elif bits_to_store == 8:
          self.EmitVStore(1, 8,
                          self.Lane(8, source, 0),
                          self.DereferenceIncrement(destination))
        else:
          raise ArgumentError('Wrong leftover: %d' % bits_to_store)
        return

  def EmitVStoreE(self, store_type, count, source, destination,
                  alignment=None):
    self.EmitVStoreAE(store_type, count, [source], destination, alignment)

  def EmitVStoreOffset(self, store_no, store_type, source, destination,
                       offset):
    self.EmitVStoreOffsetA(store_no, store_type, [source], destination, offset)

  def EmitVStoreOffsetA(self, store_no, store_type, sources, destination,
                        offset):
    self.EmitOp3('vst%d.%d' % (store_no, store_type),
                 '{%s}' % ', '.join(_ExpandQuads(sources)), destination,
                 offset)

  def EmitVStoreOffsetE(self, store_type, count, source, destination, offset):
    """Emit assembly to store a number of elements from the source registers."""
    if store_type != 32:
      raise ArgumentError('Unsupported store_type: %d' % store_type)

    sources = []
    if source[0] == 'q':
      sources.append(_Low(source))
      sources.append(_High(source))
      if count * store_type > 128:
        raise ArgumentError('Too many %dbit elements in a q register: %d' %
                            (store_type, count))
    else:
      sources.append(source)
      if count * store_type > 64:
        raise ArgumentError('Too many %dbit elements in a d register: %d' %
                            (store_type, count))

    if count == 1:
      self.EmitVStoreOffset(1, store_type,
                            self.Lane(store_type, sources[0], 0),
                            self.Dereference(destination, None), offset)
    elif count == 2:
      self.EmitVStoreOffset(1, store_type, sources[0],
                            self.Dereference(destination, None), offset)
    elif count == 3:
      self.EmitVStore(1, store_type, sources[0],
                      self.DereferenceIncrement(destination, None))
      self.EmitVStoreOffset(1, store_type,
                            self.Lane(store_type, sources[1], 0),
                            self.Dereference(destination, None), offset)
      self.EmitSub(destination, destination, self.ImmediateConstant(8))
    elif count == 4:
      self.EmitVStoreOffsetA(1, store_type, sources,
                             self.Dereference(destination, None), offset)
    else:
      raise ArgumentError('Too many elements: %d' % count)

  def EmitVSumReduce(self, reduce_type, elem_count, reduce_count, destinations,
                     sources):
    """Emit assembly for n-fold horizontal sum reduction."""
    if reduce_type != 'u32':
      raise ArgumentError('Unsupported reduce: %s' % reduce_type)

    sources = _ExpandQuads(sources)
    destinations = _ExpandQuads(destinations)

    if len(destinations) * 2 < elem_count:
      raise ArgumentError('Not enough space in destination: %d vs %d' %
                          (len(destinations) * 2, elem_count))

    if len(sources) * 2 != elem_count * reduce_count:
      raise ArgumentError('Wrong number of sources: %d vs %d' %
                          (len(sources) * 2, elem_count * reduce_count))

    if reduce_count <= 1:
      raise ArgumentError('Unsupported reduce_count: %d' % reduce_count)

    while reduce_count > 1:
      if len(sources) % 2 == 1:
        sources.append(sources[-1])

      if reduce_count == 2:
        for i in range(len(sources) // 2):
          self.EmitVPadd(reduce_type, destinations[i], sources[2 * i],
                         sources[2 * i + 1])
        return
      else:
        sources_2 = []
        for i in range(len(sources) // 2):
          self.EmitVPadd(reduce_type, sources[2 * i], sources[2 * i],
                         sources[2 * i + 1])
          sources_2.append(sources[2 * i])
        reduce_count //= 2
        sources = sources_2
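
  # Worked example (illustrative): with reduce_type='u32', elem_count=2,
  # reduce_count=4, sources=[d0, d1, d2, d3] and destinations=[d4], the first
  # pass emits vpadd.u32 d0, d0, d1 and vpadd.u32 d2, d2, d3; the final pass
  # emits vpadd.u32 d4, d0, d2, leaving the two 4-fold sums in d4.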

  def EmitVUzp(self, uzp_type, operand_1, operand_2):
    self.EmitOp2('vuzp.%d' % uzp_type, operand_1, operand_2)

  def EmitVTrn(self, trn_type, operand_1, operand_2):
    self.EmitOp2('vtrn.%d' % trn_type, operand_1, operand_2)

  def EmitColBlockStride(self, cols, stride, new_stride):
    assert cols in [1, 2, 3, 4, 5, 6, 7, 8]
    if cols in [5, 6, 7]:
      self.EmitSub(new_stride, stride, self.ImmediateConstant(4))

  def EmitLoadColBlock(self, unused_registers, load_type, cols, elements,
                       block, input_address, stride):
    """Load a block of column major data."""
    assert cols == len(block)
    assert load_type == 8

    input_deref = self.Dereference(input_address, None)
    input_deref_increment = self.DereferenceIncrement(input_address, None)

    if cols == 1:
      for i in range(elements):
        self.EmitVLoadOffset(1, 8,
                             self.Lane(8, block[0], i), input_deref, stride)
      self.EmitPld(input_address)
    elif cols == 2:
      for i in range(elements):
        self.EmitVLoadOffset(1, 16,
                             self.Lane(16, block[i // 4], i % 4), input_deref,
                             stride)
      self.EmitPld(input_address)
      self.EmitVUzp(8, block[0], block[1])
    elif cols == 3:
      for i in range(elements):
        self.EmitVLoadOffsetA(3, 8, [self.Lane(8, row, i) for row in block],
                              input_deref, stride)
    elif cols == 4:
      for i in range(elements):
        self.EmitVLoadOffset(1, 32,
                             self.Lane(32, block[i % 4], i // 4), input_deref,
                             stride)
      self.EmitPld(input_address)
      self.EmitVTrn(16, block[0], block[2])
      self.EmitVTrn(16, block[1], block[3])
      self.EmitVTrn(8, block[0], block[1])
      self.EmitVTrn(8, block[2], block[3])
    elif cols == 5:
      for i in range(elements):
        self.EmitVLoad(1, 32,
                       self.Lane(32, block[i % 4], i // 4),
                       input_deref_increment)
        self.EmitVLoadOffset(1, 8,
                             self.Lane(8, block[4], i), input_deref, stride)
      self.EmitPld(input_address)
      self.EmitVTrn(16, block[0], block[2])
      self.EmitVTrn(16, block[1], block[3])
      self.EmitVTrn(8, block[0], block[1])
      self.EmitVTrn(8, block[2], block[3])
    elif cols == 6:
      for i in range(elements):
        self.EmitVLoad(1, 32,
                       self.Lane(32, block[i % 4], i // 4),
                       input_deref_increment)
        self.EmitVLoadOffset(1, 16,
                             self.Lane(16, block[4 + i // 4], i % 4),
                             input_deref, stride)
      self.EmitPld(input_address)
      self.EmitVTrn(16, block[0], block[2])
      self.EmitVTrn(16, block[1], block[3])
      self.EmitVUzp(8, block[4], block[5])
      self.EmitVTrn(8, block[0], block[1])
      self.EmitVTrn(8, block[2], block[3])
    elif cols == 7:
      for i in range(elements):
        self.EmitVLoad(1, 32,
                       self.Lane(32, block[i % 4], i // 4),
                       input_deref_increment)
        self.EmitVLoadOffsetA(3, 8,
                              [self.Lane(8, row, i) for row in block[4:]],
                              input_deref, stride)
      self.EmitPld(input_address)
      self.EmitVTrn(16, block[0], block[2])
      self.EmitVTrn(16, block[1], block[3])
      self.EmitVTrn(8, block[0], block[1])
      self.EmitVTrn(8, block[2], block[3])
    elif cols == 8:
      for i in range(elements):
        self.EmitVLoadOffset(1, 32, block[i], input_deref, stride)
      self.EmitPld(input_address)
      self.EmitVTrn(8, block[0], block[1])
      self.EmitVTrn(8, block[2], block[3])
      self.EmitVTrn(8, block[4], block[5])
      self.EmitVTrn(8, block[6], block[7])
      self.EmitVTrn(16, block[0], block[2])
      self.EmitVTrn(16, block[1], block[3])
      self.EmitVTrn(16, block[4], block[6])
      self.EmitVTrn(16, block[5], block[7])
      self.EmitVTrn(32, block[0], block[4])
      self.EmitVTrn(32, block[1], block[5])
      self.EmitVTrn(32, block[2], block[6])
      self.EmitVTrn(32, block[3], block[7])
    else:
      assert False
    return block

  def Dereference(self, value, alignment=None):
    if alignment:
      return '[%s:%d]' % (value, alignment)
    else:
      return '[%s]' % value

  def DereferenceIncrement(self, value, alignment=None):
    return '%s!' % self.Dereference(value, alignment)

  def ImmediateConstant(self, value):
    return '#%d' % value

  def AllLanes(self, value):
    return '%s[]' % value

  def Lane(self, bits, value, lane):
    """Get the proper n-bit lane from the given register."""
    registers = []
    if value[0] == 'q':
      registers.append(_Low(value))
      registers.append(_High(value))
    else:
      registers.append(value)

    elems_per_register = 64 // bits
    register = lane // elems_per_register
    lane %= elems_per_register

    return '%s[%d]' % (registers[register], lane)

  def CreateRegisters(self):
    return _NeonRegisters32Bit()