# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools

from mlir.ir import *


def run(f):
    """Execute one test function and assert that no MLIR context outlives it.

    Prints a ``TEST: <function name>`` banner (the label directives in this
    file match on it), calls ``f`` with no arguments, forces a garbage
    collection and then asserts that ``Context._get_live_count()`` is zero.

    Args:
        f: Zero-argument callable; its ``__name__`` is printed in the banner.

    Raises:
        AssertionError: If a ``Context`` is still reported live afterwards.
    """
    print("\nTEST:", f.__name__)
    f()
    # Collect first so that objects created inside ``f`` are released before
    # the live-context count is inspected.
    gc.collect()
    assert Context._get_live_count() == 0


# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
    """Insert through ``InsertionPoint(block)``.

    The expected output lists the new op after the op that was already there.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(r"""
          func @foo() -> () {
            "custom.op1"() : () -> ()
          }
        """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint(entry_block)
        ip.insert(Operation.create("custom.op2"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()

run(test_insert_at_block_end)


# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
    """Insert through ``InsertionPoint(operation)``.

    The insertion point is built from the second op of the block; the expected
    output places the new op between the first and the second one.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(r"""
          func @foo() -> () {
            "custom.op1"() : () -> ()
            "custom.op2"() : () -> ()
          }
        """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint(entry_block.operations[1])
        ip.insert(Operation.create("custom.op3"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op3"
        # CHECK: "custom.op2"
        module.operation.print()

run(test_insert_before_operation)


# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
    """Insert through ``InsertionPoint.at_block_begin(block)``.

    The expected output lists the new op ahead of the pre-existing one.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(r"""
          func @foo() -> () {
            "custom.op2"() : () -> ()
          }
        """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        ip.insert(Operation.create("custom.op1"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()

run(test_insert_at_block_begin)


# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
    """Placeholder: currently only emits the banner via ``run``."""
    # TODO: Write this test case when we can create such a situation.
    pass

run(test_insert_at_block_begin_empty)


# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
    """Insert through ``InsertionPoint.at_block_terminator(block)``.

    The parsed block ends in ``return``; the expected output shows the new op
    following ``custom.op1``.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(r"""
          func @foo() -> () {
            "custom.op1"() : () -> ()
            return
          }
        """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint.at_block_terminator(entry_block)
        ip.insert(Operation.create("custom.op2"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()

run(test_insert_at_terminator)


# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
    """Expect ``ValueError`` from ``at_block_terminator`` on this block.

    The parsed block contains only ``custom.op1``. The exception message is
    printed so that it can be matched in the output; if no exception is
    raised the test fails with an ``AssertionError``.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with ctx:
        module = Module.parse(r"""
          func @foo() -> () {
            "custom.op1"() : () -> ()
          }
        """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        try:
            # Only the raised exception matters here, so the result is not
            # bound to a name.
            InsertionPoint.at_block_terminator(entry_block)
        except ValueError as e:
            # CHECK: Block has no terminator
            print(e)
        else:
            assert False, "Expected exception"

run(test_insert_at_block_terminator_missing)


# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
    """Create ops under nested ``InsertionPoint`` context managers.

    Ops are created with ``Operation.create`` only (no explicit ``insert``
    call) while insertion points are active via ``with``.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(r"""
          func @foo() -> () {
            "custom.op1"() : () -> ()
          }
        """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        with InsertionPoint(entry_block):
            Operation.create("custom.op2")
            with InsertionPoint.at_block_begin(entry_block):
                Operation.create("custom.opa")
                Operation.create("custom.opb")
            # Created after the inner ``with`` has exited but still inside the
            # outer one: the expected output lists custom.op3 last, after
            # custom.op2.
            Operation.create("custom.op3")
        # CHECK: "custom.opa"
        # CHECK: "custom.opb"
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        # CHECK: "custom.op3"
        module.operation.print()

run(test_insertion_point_context)