Dynamic Compilers

jit (just-in-time) compilers vs aot (ahead-of-time) compilers

a jit compiler translates code into the target isa (machine instructions) while the program executes

pro:

more information

con:

compile time slows down execution

some options

  • compile a function the first time it is called
  • compile a function after it has been called many times (needs an interpreter); we call these hot functions (see the sketch after this list)
  • build a trace of the instructions executed and compile the hot traces (a trace has no branches)
  • a variation I used: run the program to completion under a tracing interpreter, recompile offline; future executions are a mix of interpreted and compiled code
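
A minimal sketch of the hot-function counter strategy, in Python; interpret and make_native are toy stand-ins for a real interpreter and code generator.

HOT_THRESHOLD = 3        # the "magic number"; real JITs use much larger values

def interpret(body, *args):
    return body(*args)   # pretend this is slow, op-at-a-time interpretation

def make_native(body):
    return body          # pretend this emits optimized machine code

def jit(body):
    state = {"count": 0, "native": None}
    def invoke(*args):
        if state["native"] is not None:
            return state["native"](*args)       # fast path: compiled code
        state["count"] += 1
        if state["count"] >= HOT_THRESHOLD:     # the function became hot
            state["native"] = make_native(body)
        return interpret(body, *args)           # cold path: keep interpreting
    return invoke

square = jit(lambda x: x * x)
for i in range(5):
    print(square(i))     # early calls interpreted, later calls "compiled"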

Can jit compiled code run faster than aot code?

Comparison

aot                          jit
---------------------------  --------------------------------
cannot inline libraries      can inline (even class methods)
no runtime code gen          can use runtime code gen
no speculative opts          can use speculative opts
less information             more information
overall performance lower    overall performance often higher
full speed from the start    requires warmup
no compile cost at run time  overhead to run the compiler

Tradeoffs

  1. The time to compile is part of the total execution time
  2. might run fewer optimizations to keep compile time down
  3. might look at run time info
  4. same code might be compiled many times

Why would the same code be compiled more than once?

tiered compilers

Since compilation is costly, do not compile functions that are only called once and do not contain a long-running loop

we have a series of compilers, each with more aggressive optimization and each allowed to take longer

  • the lowest tier is the interpreter
  • the next is the base line compiler

  1. start interpreting the code
  2. if some part of the code takes a long time, compile it with the next higher tier
  3. if some runtime info changes, compile it again

magic numbers

associate a counter with branches and functions; if the counter reaches some magic number, use one of the compilers

if the counter is on a backward branch and triggers recompilation while the code is executing in the middle of a loop, how do you insert the newly compiled code? (this is the on-stack replacement problem)

questions when building a JIT

  • what strategy do you use to invoke the jit
  • do you have to execute for a while before calling the jit
  • how much info do you need
  • what is the price of wrong info
  • are there easy and hard programs
  • do the easy programs match up with users common programs

Speculation

  • assume some property is true and compile using that info; this is always a gamble, so you need to recover if the assumption was wrong (see the sketch after this list)
  • assume a variable is an int, and does not overflow
  • assume the properties of an object are fixed
  • assume the target of call is always the same
  • assume past behavior predicts future behavior
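
A minimal sketch of the "int that does not overflow" speculation with a recovery path; the 32-bit limit, generic_add, and the deopt structure are illustrative assumptions, not any real VM's.

INT32_MAX = 2**31 - 1

def generic_add(x, y):
    return x + y                      # slow, fully general fallback

def speculative_add(x, y):
    # fast path compiled under the gamble: both are ints, result fits in 32 bits
    if type(x) is int and type(y) is int:
        r = x + y
        if r <= INT32_MAX:
            return r                  # gamble paid off
    return generic_add(x, y)          # recover: the assumption was wrong

print(speculative_add(2, 3))          # fast path
print(speculative_add(2**31, 1))      # overflow: deoptimize to generic add
print(speculative_add("a", "b"))      # wrong type: deoptimize to generic add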

flow

%%{init: {"flowchart": {"htmlLabels": false}} }%%
graph LR
interpreter -- hot? --> profiling 
profiling -- stats --> optimizing_compiler
optimizing_compiler --> compiled_code
compiled_code -- deoptimize --> interpreter
interpreter -- already_compiled --> compiled_code

boxed values

Many languages do not use strong static typing

for example in python

x = x + 1

x could be an int/float/object/string etc

the value of x needs to carry a type. Represent x as a pair (type, pointer or bits); the pair is called a boxed value

then to generate code for the plus we have to figure out what kind of add to emit, based on the type
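
A minimal sketch of boxed values and type-dispatched addition; the tags and representation are illustrative, not any particular VM's.

class Boxed:
    def __init__(self, tag, payload):
        self.tag = tag          # 'int', 'float', 'str', ...
        self.payload = payload  # the raw bits or a pointer to an object

def boxed_add(a, b):
    # the generated code must pick the right "add" based on the runtime tags
    if a.tag == 'int' and b.tag == 'int':
        return Boxed('int', a.payload + b.payload)     # integer add
    if a.tag in ('int', 'float') and b.tag in ('int', 'float'):
        return Boxed('float', float(a.payload) + float(b.payload))  # float add
    if a.tag == 'str' and b.tag == 'str':
        return Boxed('str', a.payload + b.payload)     # string concat
    raise TypeError(f"unsupported: {a.tag} + {b.tag}")

x = Boxed('int', 41)
one = Boxed('int', 1)
print(boxed_add(x, one).payload)   # 42, dispatched to the integer add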

inline caches

in languages like python, method calls are more expensive than method calls in C++. why?

. . .

Python objects are implemented as hash tables, while C++ uses virtual tables

how does that affect the cost?

first C++ virtual tables

in C++ a method call takes two dereferences

  1. first find the v-table
  2. second, use a fixed offset from the table start to find the address

What do we need to keep the offset fixed?

if Derived inherits from Base, and both have a function f, the offset to f has to be the same.

in languages where objects are hash tables, the c++ dereference becomes a hash table lookup, which is slower

tradeoffs

In a dynamically typed language like python we can add or remove methods easily

but method calls are expensive

we want to make these calls cheaper

inline caches at each call site

the first time we call a method, we know the type (because we are generating code at runtime)

def func(a, b, c):
    for i in range(10):
        foo(a, b, c)

becomes

def func(a, b, c):
    for i in range(10):
        if isinstance(a, type1):
            ... body of foo, inlined ...
        else:
            other = lookup 'foo' in the hash table
            other(a, b, c)

inline caches at the function site

def func(a, b, c):
    for i in range(10):
        _foo(a, b, c)

def _foo(a, b, c):
    if isinstance(a, type1):
        ... body of foo, inlined ...
    else:
        other = lookup 'foo' in a
        other(a, b, c)

is it better to do this at the call site or at the function site?

polymorphic calls

if the type changes at runtime (the call to other is taken) does the optimization help?

could invalidate the table and rebuild it with another case

what are the costs?

for example, v8's inline cache costs (a sketch of the mechanism follows):

  • monomorphic inline cache hit: 10 instructions
  • polymorphic hit: 35 instructions for 10 types, 60 instructions for 20 types
  • cache miss: 1000-4000 instructions
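
A minimal sketch of the mechanism behind those numbers: an inline cache that checks the cached types before falling back to the slow hash-table lookup; the classes and cache layout are illustrative.

class InlineCache:
    def __init__(self, name):
        self.name = name
        self.entries = []               # list of (type, method) pairs

    def call(self, receiver, *args):
        t = type(receiver)
        for cached_type, method in self.entries:
            if t is cached_type:        # cache hit: a cheap type check
                return method(receiver, *args)
        method = getattr(t, self.name)  # cache miss: slow hash-table lookup
        self.entries.append((t, method))
        return method(receiver, *args)

class Car:
    def honk(self): return "car honk"

class Truck:
    def honk(self): return "truck honk"

honk_site = InlineCache('honk')
print(honk_site.call(Car()))    # miss, then cached
print(honk_site.call(Car()))    # monomorphic hit
print(honk_site.call(Truck()))  # second type: the cache becomes polymorphic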

value specialization

Oddly, many functions are called repeatedly with the same arguments

an example

given a vector v of size n, and a parameter q find the element of v that is closest to q

function closest(v, q, n) {
    if (n == 0) {
        throw "Error";
    } else {
        var i = 0;
        var d = 0xffffffff;
        while (i < n) {
            var nd = abs(v[i] - q);
            if (nd <= d) d = nd;
            i++;
        }
        return d;
    }
}

the cfg

we want to recompile this for specific v, q, and n, where we restart at the while test


%%{init: {"flowchart": {"htmlLabels": false}} }%%
graph TD
normal_entry["function entry
              v = param[0]
              q = param[1]
              n = param[2]
              if (n ==0) goto l1"]

l1["l1: throw error"]
l2[" l2: i0 = 0
     d = 0xffffffff"]
normal_entry --> l1
normal_entry--> l2
l3["l3: i1 = phi(i0, i2, i3)
    d1 = phi(d0, d3, d4)
    if (i1 < n) go to l5"  ]
l2--> l3
entry_on_stack_rep["start replace
                   v = param[0]
                  q = param[1]
                  n = param[2]
                  i3 = stack[0]
                  d4 = stack[1]"]
entry_on_stack_rep --> l3
l5["l5: t0 = 4* i
     t1 = v[t0]
     notinbounds(t1, n) go to l8"]

l3 --> l5 
l3--> l4
l4["l4: return d1"]
l5--> l7
l7[" l7: nd = abs(t1, q)
   if (nd > d1) go to l9"]

l9["l9: d3 = phi(d1, d2)
   i2 = i1 + 1
   goto l3"]
l7--> l9
l7--> l6["l6: d2 = nd"]
l6--> l9
l8["l8: throw boundsError"]
l5 --> l8
l9--> l3


two entries

First entry is the regular starting point, second is the entry if we are currently running the loop in the interpreter

Since we are compiling the function while in the loop we can ask the interpreter for values

  • v = load[0]
  • q = 42
  • n = 100
  • i = 40
  • d = 0xffffffff


%%{init: {"flowchart": {"htmlLabels": false}} }%%
graph TD
normal_entry["function entry
              v = load[0]
              q = 42
              n = 100
              if (n ==0) goto l1"]

l1["l1: throw error"]
l2[" l2: i0 = 0
     d = 0xffffffff"]
normal_entry --> l1
normal_entry--> l2
l3["l3: i1 = phi(i2, i3)
    d1 = phi(d3, d4)
    if (i1 < n) go to l5"  ]
l2--> l3
entry_on_stack_rep["start replace
                   v = load [0]
                  q = 42
                  n = 100
                  i3 = 40
                  d4 = 0xffffffff"]
entry_on_stack_rep --> l3
l5["l5: t0 = 4* i
     t1 = v[t0]
     notinbounds(t1, n) go to l8"]

l3 --> l5 
l3--> l4
l4["l4: return d1"]
l5--> l7
l7[" l7: nd = abs(t1, q)
   if (nd > d1) go to l9"]

l9["l9: d3 = phi(d1, d2)
   i2 = i1 + 1
   goto l3"]
l7--> l9
l7--> l6["l6: d2 = nd"]
l6--> l9
l8["l8: throw boundsError"]
l5 --> l8
l9--> l3


dead code elimination

After this, all calls to the function assume these arguments, so there is no need to keep the regular entry


%%{init: {"flowchart": {"htmlLabels": false}} }%%
graph TD

l3["l3: i1 = phi(i2, i3)
    d1 = phi(d3, d4)
    if (i1 < n) go to l5"  ]
entry_on_stack_rep["start replace
                   v = load [0]
                  q = 42
                  n = 100
                  i3 = 40
                  d4 = 0xffffffff"]
entry_on_stack_rep --> l3
l5["l5: t0 = 4* i
     t1 = v[t0]
     notinbounds(t1, n) go to l8"]

l3 --> l5 
l3--> l4
l4["l4: return d1"]
l5--> l7
l7[" l7: nd = abs(t1, q)
   if (nd > d1) go to l9"]

l9["l9: d3 = phi(d1, d2)
   i2 = i1 + 1
   goto l3"]
l7--> l9
l7--> l6["l6: d2 = nd"]
l6--> l9
l8["l8: throw boundsError"]
l5 --> l8
l9--> l3


array bounds check

we can pattern-match loops against their bounds checks and remove the check when we know the loop limit



%%{init: {"flowchart": {"htmlLabels": false}} }%%
graph TD

l3["l3: i1 = phi(i0, i2, i3)
    d1 = phi(d0, d3, d4)
    if (i1 < n) go to l5"  ]
entry_on_stack_rep["start replace
                   v = load [0]
                  q = 42
                  n = 100
                  i3 = 40
                  d4 = 0xffffffff"]
entry_on_stack_rep --> l3
l5["l5: t0 = 4* i
     t1 = v[t0]
     "]

l3 --> l5 
l3--> l4
l4["l4: return d1"]
l5--> l7
l7[" l7: nd = abs(t1, q)
   if (nd > d1) go to l9"]

l9["l9: d3 = phi(d1, d2)
   i2 = i1 + 1
   goto l3"]
l7--> l9
l7--> l6["l6: d2 = nd"]
l6--> l9
l9--> l3


loop inversion

a general while loop might never run, so we cannot move code out of the loop

while(cond){
  ...
}

can be changed into

if (cond){
  do {
    ...
  } while(cond)
}

for this loop, the first time around i = 40 and n = 100, so the initial condition is known to be true

after loop inversion

%%{init: {"flowchart": {"htmlLabels": false}} }%%
graph TD
l3["l3: i1 = phi(i2, i3)
    d1 = phi(d3, d4)" ]
entry_on_stack_rep["v = load [0]
                  q = 42
                  n = 100
                  i3 = 40
                  d4 = 0xffffffff"]
entry_on_stack_rep --> l3
l3 --> l7


l4["l4: return d1"]
l7[" l7: l5: t0 = 4* i
t1 = v[t0]
nd = abs(t1, q)
   if (nd > d1) go to l9"]

l9["l9: d3 = phi(d1, d2)
   i2 = i1 + 1
   if (i2 >= n) goto l4"]
l7--> l9
l7--> l6["l6: d2 = nd"]
l6--> l9
l9--> l3
l9--> l4


results

specialized code is shorter and compiles faster

since we know that the loop runs from i = 40 to 100, we could unroll the loop
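
A sketch of what value specialization buys, written back as source: q = 42 and n = 100 are baked in, the resume point is i = 40, and the bounds check is gone; closest_specialized is a hypothetical name for the specialized entry.

def closest_specialized(v):
    d = 0xffffffff
    for i in range(40, 100):      # i resumes at 40, n frozen at 100
        nd = abs(v[i] - 42)       # q frozen at 42
        if nd <= d:
            d = nd
    return d

print(closest_specialized(list(range(200))))   # 0: v[42] == 42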

dynamic compilation and pgo (profile guided optimization)

workload dependent – dynamically you only measure what you actually see – that can lead to sub-optimal choices if you make them “forever”

If your program has phases, where it does some kind of behavior for a while, and then shifts to another behavior afterwards, then that workload that you originally measured isn’t representative anymore!

When you do PGO you’re supposed to do it on the entire program and it’s supposed to be representative, whereas when you’re tracing you’re doing that live!

confidence vs time

set the magic numbers high: run slowly for a long time, but gather more info

set the magic numbers low: recompile quickly, based on less info

Dynamic Compilers part 2

C++ method call

each object has a pointer to a vtable, which is a table of function pointers to all the virtual functions

each derived object has its own vtable, which has the same offsets for all the common virtual functions

to find the code takes two dereferences

  1. find the vtable (one per class) (one load)
  2. at a fixed offset (determined by the virtual function name) find the code (second load)

a python example

more flexible

  1. find the hash table (one per instance)
  2. lookup the virtual function in the hash table

class Thing:
    def __init__(self, kind):
        self.kind = kind 

thing = Thing('car')

def honk(self):
    print(f"{self.kind} says Honk")

thing.honk = honk.__get__(thing)  ## add a method dynamically to one instance 

thing.honk()  ## call it 

honk.__get__(thing) returns a bound method; when this method is called, thing is passed as the first argument

dynamic chunks

So far:

  1. run interpreter or tier 0 compiler
  2. collect statistics on call counts or branch counts
  3. when count is high enough recompile the hot functions
  4. specialize the hot functions based on common values

The unit of compilation is the static function

an alternative is called trace compilation

  1. run interpreter or tier 0 compiler
  2. collect the statements executed (do not collect control flow)
  3. this produces a linear trace of instructions (see the sketch after this list)
  4. recompile the trace using optimizations like value numbering
  5. if the next time the code executes, it takes a different path, fix things up
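
A minimal sketch of recording a trace over a toy instruction set; the opcode format and guard representation are illustrative, not any real JIT's.

def record_trace(program, env):
    """Run program once, returning straight-line code with guards."""
    trace, pc = [], 0
    while pc < len(program):
        instr = program[pc]
        if instr[0] == 'br':                    # ('br', var, then_pc, else_pc)
            _, var, then_pc, else_pc = instr
            taken = env[var]
            trace.append(('guard', var, taken)) # a branch becomes a guard
            pc = then_pc if taken else else_pc
        elif instr[0] == 'set':                 # ('set', dst, fn, srcs)
            _, dst, fn, srcs = instr
            env[dst] = fn(*[env[s] for s in srcs])
            trace.append(instr)                 # straight-line op: record as-is
            pc += 1
        else:                                   # ('ret', var)
            trace.append(instr)
            break
    return trace                                # hand this to the optimizer

prog = [
    ('set', 'y', lambda x: x + 1, ['x']),
    ('br', 'small', 2, 3),
    ('ret', 'y'),
    ('ret', 'x'),
]
print(record_trace(prog, {'x': 42, 'small': True}))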

trace compilation 0

In a linear trace the number of assumptions you’re making accumulates as you execute [towards the end of the trace you have the most assumptions built up]

If you have an always-taken control flow, e.g. some virtual function call that’s always calling the same actual function, a tracing compiler will treat all back-to-back branches as one set of straight line code

execute this all back to back and, whenever convenient, check whether any of those assumptions were wrong

cold path

On the “cold path” – again, when it’s convenient – undo all the inapplicable things if it turns out the branches weren’t taken

Called a “bailout” [“bailing out” of the trace]

At a bailout there is new information: something you didn’t observe when you were tracing

You trust that everything you’ve traced is going to happen, it’s all going to go well, and you’re going to be able to optimize for it

But then at runtime, when convenient, you’re going to check, and then bail if you were wrong, and have the rollback on the cold path

So the hot path, the one you’re pretty sure is going to execute, is quite optimal

trace compilation 2

tracing jit: extract a hot path (not a function)

Hot paths are compiled as a single basic block, but the path might go through a call

gamble: the next execution starting at this point goes the same way; no branches leave the path

generate machine code for hot paths; interpret the rest of the program

unlike specialization, tracing assumes the same path but not the same values

an example (x = 42)

function main(x){
   y = x + 1
   if (x < 100) {
      z = f(y)
   } else {
      z = g(y)
   }
   return z
}

function f(a){
   return a - 1
}

the recorded trace for x = 42:

  • y = x + 1
  • guard(x < 100)
  • a = y
  • z = a - 1
  • return z

guards at divergence, guards never return

optimize assuming guards are true, ok to be slow if guard is false

move guards up

why is this a good idea?

. . .

  • fail fast
  • longer region to optimize

use local value numbering

  • guard(x < 100)
  • y = x + 1
  • a = y
  • z = a - 1
  • return z

value numbering and copy propagation show z = (x + 1) - 1 = x, so the trace collapses to

  • guard(x < 100)
  • return x

how do we do this in Bril?

3 new operations (sort of like out-of-order instructions)

  1. speculate - Enter a speculative execution context. No arguments.
  2. commit - End the current speculative context, committing the current speculative state as the “real” state. No arguments.
  3. guard - Check a condition and possibly abort the current speculative context. One argument, the Boolean condition, and one label, to which control is transferred on abort.

speculate extension

these three operations form Bril's speculative execution extension

example

b: bool = const false;
v: int = const 4;      # v == 4
speculate;             # enter the speculative context
v: int = const 2;      # v == 2 (speculative state)
guard b .failed;       # b is false: abort and jump to .failed
commit;                # (not reached in this run)

.failed:
  print v;             # prints 4: the speculative write to v was rolled back

implementation

you can add a tracer to an interpreter

In a lot of language environments you’ll have an interpreter that’s executing “op at a time”

hook in a tracer which observes what the interpreter is doing and “make some machine code on the side” based on how the interpreter ran

you can implement just a subset of the operations [ed: you might call this property “compiler completeness” for your op set]

common bytecode operations

implement only the common ones and simply end the trace when you hit one that was not implemented, because it was uncommon

You can build up this trace JIT-ing capability over time; because the system is built with this assumption, you can bail out of the trace for whatever reason and go back to the interpreter

an example

You could imagine making a JIT that just covered MULs and ADDs and could make fused/composite MUL/ADD bytecode combinations

Specialize that for one common type: e.g. if you have many types in your language, you could support just integer types, or just FP ops if it were numerical code, and then bail if any other types showed up at runtime (a sketch follows)
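
A minimal sketch of that partial coverage: keep only the prefix of a recorded trace whose ops are supported integer MUL/ADDs, and bail to the interpreter at the first unsupported op; the opcode names and trace format are illustrative.

SUPPORTED = {'add', 'mul'}

def split_trace(ops):
    """Return (compilable prefix, interpreted suffix) of a recorded trace."""
    for i, (op, operands) in enumerate(ops):
        if op not in SUPPORTED or not all(type(v) is int for v in operands):
            return ops[:i], ops[i:]      # bail here: back to the interpreter
    return ops, []

trace = [('mul', (3, 4)), ('add', (12, 5)), ('str_concat', ('a', 'b'))]
hot, cold = split_trace(trace)
print(hot)    # [('mul', (3, 4)), ('add', (12, 5))] -> compile these
print(cold)   # [('str_concat', ('a', 'b'))] -> interpret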

trace invariants: suppose traces transfer to other traces

trace1 with set of ops A, trace2 with set of ops B, and we see a transfer from A to B

make sure that the assumptions between those two things are lining up – called trace fusion

know the invariants (i.e. “what must be true”) on the exit from A and the entry to B are lining up / compatible with each other

method inlining

In a trace compiler you just execute through methods

Inlining is the natural path of compilation when doing trace compilation – just linear execution where the jumps/calls/returns simply disappear

tail duplication

it is common that multiple traces have a common tail

for () {
  if op_eq {
     op1
  } else {
     op2
  }
}
op_t
op_a
op_i
op_l

trace0: op_eq guard(true)  op1 op_t op_a op_i op_l

trace1: op_eq guard(false) op2 op_t op_a op_i op_l

two traces with the same ending; we could generate

one copy of the tail, with arguments showing which header was taken

two copies of the tail, each with its header frozen in

adding traces to bril

How to modify the reference interpreter (warning typescript!)

brili

there are two functions to consider (a toy sketch follows the list)

  1. evalFunc interprets a function by calling evalInstr on each instruction
  2. evalInstr interprets one instruction, large case statement for each instruction
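
A Python-flavored toy of that structure and of where a tracer hooks in (brili itself is TypeScript); the instruction format loosely mimics Bril's JSON, and the tracer just records instructions on the side.

def eval_instr(instr, env):
    op = instr['op']
    if op == 'const':
        env[instr['dest']] = instr['value']
    elif op == 'add':
        a, b = instr['args']
        env[instr['dest']] = env[a] + env[b]
    elif op == 'print':
        print(env[instr['args'][0]])

def eval_func(func, env, tracer=None):
    for instr in func['instrs']:
        if tracer is not None:
            tracer.append(dict(instr))   # observe what the interpreter runs
        eval_instr(instr, env)

trace = []
main = {'instrs': [
    {'op': 'const', 'dest': 'x', 'value': 40},
    {'op': 'const', 'dest': 'y', 'value': 2},
    {'op': 'add', 'dest': 'z', 'args': ['x', 'y']},
    {'op': 'print', 'args': ['z']},
]}
eval_func(main, {}, tracer=trace)   # prints 42 and records four instructions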

traces and users

when a trace works well it looks amazing: it finds the inner loop and optimizes even through libraries

but users find it hard to understand what the compiler did

a tiny source change can make a big trace change

hard to fit in a debugger

security is a problem

pytorch 2.0

ml frameworks have two modes

Eager Mode

  • Preferred by users
  • Easier to use programming model
  • Easy to debug

a + b + c executes two calls to torch.add (if they are tensors)

no place to optimize, allows any kind of python, and any control flow

  • PyTorch is primarily an eager-mode framework

Graph Mode

  • Preferred by backends and framework builders
  • Easier to optimize with a compiler
  • Easier to do automated transformations

construct a graph with two add nodes and 3 input nodes, then execute the graph

easy to optimize, only graph nodes allowed, no control flow

Main optimization is fusing operations to avoid memory copies

how does the compiler fit

in Eager mode there is only a library - no compiler

if you have a matmul followed by an activation function, it is up to the developer to notice that the memory traffic costs more than the activation, to know that there is another pytorch call (among ~2000 of them) that does the combined operation, and to change the code

in graph mode (compiler writers call this deferred or lazy) the operations get recorded (not executed) and only run when we need the result
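
A minimal sketch of deferred execution: operator calls build graph nodes instead of computing, and the graph only runs when a result is demanded. The Lazy class is illustrative, not PyTorch's implementation.

class Lazy:
    def __init__(self, op, *inputs):
        self.op, self.inputs = op, inputs

    def __add__(self, other):
        return Lazy('add', self, other)   # record the op, don't execute it

    def value(self):
        if self.op == 'leaf':
            return self.inputs[0]
        if self.op == 'add':              # execute only when demanded
            return self.inputs[0].value() + self.inputs[1].value()

a, b, c = Lazy('leaf', 1), Lazy('leaf', 2), Lazy('leaf', 3)
graph = a + b + c        # builds two add nodes and three inputs; nothing runs
print(graph.value())     # 6: the whole graph executes here

Because the whole graph is visible before anything runs, an optimizer gets a chance to fuse the two adds; in eager mode each add would have executed immediately.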

PyTorch’s Many Attempts at Graph Modes

torch.jit.trace

  1. Record + replay
  2. Unsound
  3. Can give incorrect results because it ignores the Python part of the program

torch.jit.script

  1. AOT parses Python into graph format
  2. Only works on ~45% of real world models
  3. High effort to “TorchScript” models

PyTorch Models Are Not Static Graphs

PyTorch users write models where program graphs are impossible

Convert tensors to native Python types (x.item(), x.tolist(), int(x), etc)

Use other frameworks (numpy/xarray/etc) for part of their model

Data-dependent Python control flow or other dynamism: exceptions, closures, generators, classes, etc

torch xla

deferred execution: rather than do the graph operation, just save it and execute as late as possible

very slow, big performance cliffs

torch.compile(model) - converts a pytorch eager program to a graph

TorchDynamo - dynamically captures Python code execution and creates a static computational graph

TorchInductor - a compiler that optimizes static computation graphs

dynamo

import torch
from typing import List

import torch._dynamo
torch._dynamo.config.suppress_errors = True

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my compiler() called with fx graph")
    gm.graph.print_tabular()
    return gm

@torch.compile(backend=my_compiler)
def toy_example(a,b):
    x = a / (torch.abs(a)+1)
    if b.sum() < 0:
        b  = b * -1
    return x *b

for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

output

def toy_example(a,b):
    x = a / (torch.abs(a)+1)
    if b.sum() < 0:
        b  = b * -1
    return x *b
graph 1 (shared prefix, up to the branch):

opcode         name    target      args
-------------  ------  ---------   -----------
placeholder    l_a_    L_a_        ()
placeholder    l_b_    L_b_        ()
call_function  abs_1   <abs>       (l_a_,)
call_function  add     <add>       (abs_1, 1)
call_function  x       <truediv>   (l_a_, add)
call_method    sum_1   sum         (l_b_,)
call_function  lt      <lt>        (sum_1, 0)
output         output  output      ((x, lt),)

graph 2 (branch taken: b = b * -1):

opcode         name    target      args
-------------  ------  ---------   -----------
placeholder    l_b_    L_b_        ()
placeholder    l_x_    L_x_        ()
call_function  b       <mul>       (l_b_, -1)
call_function  mul_1   <mul>       (l_x_, b)
output         output  output      ((mul_1,),)

graph 3 (branch not taken):

opcode         name    target      args
-------------  ------  ---------   -----------
placeholder    l_x_    L_x_        ()
placeholder    l_b_    L_b_        ()
call_function  mul     <mul>       (l_x_, l_b_)
output         output  output      ((mul,),)

implementation

python builds a frameObject (pointer to codeObject + arguments)

passes this to eval

codeObject allows for extra user data and for a user function to be called between the frameObject and eval

This makes it easy to add a custom JIT

split the function into two parts - the python part and the torch part

The compiled part is reused on later calls if the guards pass

pytorch example 1

def toy_example(a,b):
    x = a / (torch.abs(a)+1)
    if b.sum() < 0:
        b  = b * -1
    return x *b

for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

sometimes the sum is negative but not always

graph

opcode         name    target                                                   args
-------------  ------  -------------------------------------------------------  -----------
placeholder    l_a_    L_a_                                                     ()
placeholder    l_b_    L_b_                                                     ()
call_function  abs_1   <built-in method abs of type object at 0x728736add8a0>   (l_a_,)
call_function  add     <add>                                                    (abs_1, 1)
call_function  x       <truediv>                                                (l_a_, add)
call_method    sum_1   sum                                                      (l_b_,)
call_function  lt      <lt>                                                     (sum_1, 0)
output         output  output                                                   ((x, lt),)

code

def toy_example(a, b):
    (x, lt) = call1(a, b)
    if lt:
        return f1(b, x)
    else:
        return f2(b, x)

def f1(b, x):
    b = b * -1
    return x * b

def f2(b, x):
    return x * b

guards:

check_tensor(L['a'], Tensor, torch.float32, size=[10], stride=[1])
check_tensor(L['b'], Tensor, torch.float32, size=[10], stride=[1])

walk the byte code again for f1

b = b * -1
return x * b

TRACED GRAPH

opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  ------
placeholder    l_b_    L_b_                     ()           {}
placeholder    l_x_    L_x_                     ()           {}
call_function  b       <built-in function mul>  (l_b_, -1)   {}
call_function  mul_1   <built-in function mul>  (l_x_, b)    {}
output         output  output                   ((mul_1,),)  {}

other branch

guards: check_tensor(L['b'], torch.float32, size=[10], stride=[1]) 
        check_tensor(L['x'], torch.float32, size=[10], stride=[1]) 

TRACED GRAPH

opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  ------
placeholder    l_x_    L_x_                     ()            {}
placeholder    l_b_    L_b_                     ()            {}
call_function  mul     <built-in function mul>  (l_x_, l_b_)  {}
output         output  output                   ((mul,),)     {}


linear traces

Dynamo removes all control flow, if/else, loops, exceptions

specializes (bakes in) all non-tensor objects (numbers, strings, classes)


@torch.compile
def fn(x, n):
    y = x ** 2
    if n >= 0:
        return (n + 1) * y
    else:
        return x / y

x = torch.randn(200)
fn(x, 2)

the captured graph specializes n = 2:

def forward(l_x_: torch.Tensor):
    y = l_x_ ** 2
    mul = 3 * y
    return (mul,)

special cases

Trace integers symbolically

by default it specializes on every integer in the graph, but if the value changes on a subsequent call it traces that integer symbolically; 0 and 1 are always specialized

multiple traces

import torch
from typing import List

import torch._dynamo
torch._dynamo.config.suppress_errors = True

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my compiler() called with fx graph")
    gm.graph.print_tabular()
    return gm

@torch.compile(backend=my_compiler)
def toy_example(a,b):
    x = a / (torch.abs(a)+1)
    if b.sum() < 0:
        b  = b * -1
    return x *b

for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))



implementation

PEP 523 allows a Python function to see unevaluated frames (function + arguments)

normally the hook just calls the function (a toy simulation follows)
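
A toy simulation of that hook in pure Python; the real hook is installed at the C level (PEP 523's frame-evaluation function), and eval_hook, the guard format, and the cache here are all illustrative.

compiled_cache = {}   # function -> (guard, compiled callable)

def eval_hook(func, *args):
    entry = compiled_cache.get(func)
    if entry:
        guard, compiled = entry
        if guard(*args):             # guards pass: reuse the compiled part
            return compiled(*args)
    return func(*args)               # otherwise: "normally just calls the function"

def f(x):
    return x * 2

# pretend a JIT produced a guarded, specialized version of f
compiled_cache[f] = (lambda x: type(x) is int, lambda x: x << 1)
print(eval_hook(f, 21))     # 42 via the "compiled" shift version
print(eval_hook(f, 1.5))    # 3.0 via the normal call: the guard failed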
