Computer Laboratory

FieldReference.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2013 Jonathan Anderson
3  * All rights reserved.
4  *
5  * This software was developed by SRI International and the University of
6  * Cambridge Computer Laboratory under DARPA/AFRL contract (FA8750-10-C-0237)
7  * ("CTSRD"), as part of the DARPA CRASH research programme.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  * 1. Redistributions of source code must retain the above copyright
13  * notice, this list of conditions and the following disclaimer.
14  * 2. Redistributions in binary form must reproduce the above copyright
15  * notice, this list of conditions and the following disclaimer in the
16  * documentation and/or other materials provided with the distribution.
17  *
18  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
19  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21  * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
22  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
24  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
25  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
26  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
27  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
28  * SUCH DAMAGE.
29  */
30 
31 #include "Annotations.h"
32 #include "Debug.h"
33 #include "FieldReference.h"
34 #include "Instrumentation.h"
35 #include "Manifest.h"
36 #include "Names.h"
37 #include "Transition.h"
38 
39 #include <llvm/ADT/StringMap.h>
40 #include <llvm/IR/Constants.h>
41 #include <llvm/IR/Instructions.h>
42 #include <llvm/IR/Module.h>
43 #include <llvm/Support/raw_ostream.h> // TODO: tmp
44 
45 #include <map>
46 #include <set>
47 
48 using namespace llvm;
49 using std::map;
50 using std::set;
51 using std::string;
52 
53 
54 namespace tesla {
55 
56 char FieldReferenceInstrumenter::ID = 0;
57 raw_ostream& debug = debugs("tesla.instrumentation.field_assign");
58 
59 
62 public:
63  void AppendInstrumentation(const Automaton&, const TEquivalenceClass&);
64 
65  std::string CompleteFieldName() const {
66  return (StructTy->getName() + "." + FieldName).str();
67  }
68 
69  Function* getTarget() const { return InstrFn; }
70 
71  FieldInstrumentation(Function *InstrFn, Module& M, const StructType *T,
72  const StringRef FieldName, size_t FieldIndex)
73  : InstrFn(InstrFn), Exit(FindBlock("exit", *InstrFn)), M(M),
74  StructTy(T), FieldName(FieldName)
75  {
76  }
77 
78 private:
79  BasicBlock* NextInstrBlock(const Automaton*);
80 
81  Function *InstrFn;
82  BasicBlock *Exit;
83  Module& M;
84  const StructType *StructTy;
85  const StringRef FieldName;
86 
87  map<const Automaton*,BasicBlock*> NextInstr;
88 };
89 
90 
91 FieldReferenceInstrumenter::~FieldReferenceInstrumenter() {
92  for (auto& i : Instrumentation)
93  delete i.second;
94 }
95 
96 
97 bool FieldReferenceInstrumenter::runOnModule(Module &Mod) {
98  debug
99  << "===================================================================\n"
100  << __PRETTY_FUNCTION__ << "\n"
101  << "-------------------------------------------------------------------\n"
102  << "module: " << Mod.getModuleIdentifier() << "\n";
103 
104  this->Mod = &Mod;
105 
106  //
107  // First, find all struct fields that we want to instrument.
108  //
109  for (auto *Root : M.RootAutomata())
110  BuildInstrumentation(*M.FindAutomaton(Root->identifier()));
111 
112  debug << "instrumentation:\n";
113  for (auto& i : Instrumentation) {
114  debug << " " << i.getKey() << " -> ";
115  i.getValue()->getTarget()->getType()->print(debug);
116  debug << "\n";
117  }
118 
119  debug
120  << "-------------------------------------------------------------------\n"
121  << "looking for field references...\n"
122  ;
123 
124  //
125  // Then, iterate through all uses of the LLVM pointer annotation and look
126  // for structure accesses.
127  //
128  std::map<LoadInst*,FieldInstrumentation*> Loads;
129  std::map<StoreInst*,FieldInstrumentation*> Stores;
130 
131  //
132  // Look through all of the functions that start with llvm.ptr.annotation.
133  //
134  for (Function& Fn : Mod.getFunctionList()) {
135  if (!Fn.getName().startswith(LLVM_PTR_ANNOTATION))
136  continue;
137 
138  for (auto i = Fn.use_begin(); i != Fn.use_end(); i++) {
139  // We should be able to do some parsing of all annotations.
140  OwningPtr<PtrAnnotation> A(PtrAnnotation::Interpret(*i));
141  assert(A);
142 
143  // We only care about struct field annotations; ignore everything else.
144  auto *Annotation = dyn_cast<FieldAnnotation>(A.get());
145  if (!Annotation)
146  continue;
147 
148  // Not every struct field will have instrumentation defined for it.
149  auto Name = Annotation->completeFieldName();
150  auto *Instr = Instrumentation[Name];
151  if (Instr == NULL)
152  continue;
153 
154  for (User *U : *Annotation) {
155  auto *Cast = dyn_cast<CastInst>(U);
156  if (!Cast) {
157  U->dump();
158  panic("annotation user not a bitcast", false);
159  }
160 
161  for (auto k = Cast->use_begin(); k != Cast->use_end(); k++) {
162  if (auto *Load = dyn_cast<LoadInst>(*k))
163  Loads.insert(std::make_pair(Load, Instr));
164 
165  else if (auto *Store = dyn_cast<StoreInst>(*k))
166  Stores.insert(std::make_pair(Store, Instr));
167 
168  else {
169  k->dump();
170  panic("expected load or store with annotated value", false);
171  }
172  }
173  }
174  }
175  }
176 
177  for (auto i : Loads)
178  InstrumentLoad(i.first, i.second);
179 
180  for (auto i : Stores)
181  InstrumentStore(i.first, i.second);
182 
183  return true;
184 }
185 
186 
187 void FieldReferenceInstrumenter::BuildInstrumentation(const Automaton& A) {
188  for (auto& Transitions : A)
189  GetInstr(A, Transitions);
190 }
191 
192 
193 FieldInstrumentation* FieldReferenceInstrumenter::GetInstr(
194  const Automaton& A, const TEquivalenceClass& Trans) {
195 
196  auto *Head = dyn_cast<FieldAssignTransition>(*Trans.begin());
197  if (!Head) // ignore other kinds of transitions
198  return NULL;
199 
200  debug << Head->String() << "\n";
201  auto& Protobuf = Head->Assignment();
202  auto StructName = Protobuf.field().type();
203  auto FieldName = Protobuf.field().name();
204  string FullName = StructName + "." + FieldName;
205 
206  FieldInstrumentation *Instr;
207 
208  auto Existing = Instrumentation.find(FullName);
209  if (Existing != Instrumentation.end())
210  Instr = Existing->second;
211 
212  else {
213  StructType *T = Mod->getTypeByName("struct." + StructName);
214  if (!T) // ignore structs that aren't used by this module
215  return NULL;
216 
217  Function *InstrFn =
218  StructInstrumentation(*Mod, T, FieldName, Protobuf.field().index(), true,
219  SuppressDebugInstr);
220 
221  Instr = new FieldInstrumentation(InstrFn, *Mod, T,
222  FieldName, Protobuf.field().index());
223 
224  Instrumentation[FullName] = Instr;
225  }
226 
227  Instr->AppendInstrumentation(A, Trans);
228 
229  return Instr;
230 }
231 
232 
233 bool FieldReferenceInstrumenter::InstrumentLoad(
234  LoadInst*, FieldInstrumentation*) {
235 
236  //
237  // We don't actually instrument loads yet: we can't describe such references
238  // in the C-based automaton description language.
239  //
240 
241  return true;
242 }
243 
244 
245 bool FieldReferenceInstrumenter::InstrumentStore(
246  StoreInst *Store, FieldInstrumentation *Instr) {
247 
248  assert(Store != NULL);
249  assert(Instr != NULL);
250 
251  debug << "instrumenting: ";
252  Store->print(debug);
253  debug << "\n";
254 
255  assert(Store->getNumOperands() > 1);
256  Value *Val = Store->getOperand(0);
257  Value *Ptr = Store->getOperand(1);
258 
259  // Find the struct pointer this field was derived from.
260  Value *V = Ptr;
261  Value *StructPtr = NULL;
262 
263  do {
264  User *U = dyn_cast<User>(V);
265  if (!U) {
266  V->print(debug);
267  debug << " is not a User!\n";
268  panic("expected a User");
269  }
270 
271  assert(U->getNumOperands() > 0);
272  V = U->getOperand(0);
273 
274  auto *PointerTy = dyn_cast<PointerType>(V->getType());
275  if (PointerTy && PointerTy->getElementType()->isStructTy())
276  StructPtr = V;
277 
278  } while (StructPtr == NULL);
279 
280  std::vector<Value*> Args;
281  Args.push_back(StructPtr);
282  Args.push_back(Val);
283  Args.push_back(Ptr);
284 
285  IRBuilder<> Builder(Store);
286  Builder.CreateCall(Instr->getTarget(), Args);
287 
288  return true;
289 }
290 
291 
292 void FieldInstrumentation::AppendInstrumentation(
293  const Automaton& A, const TEquivalenceClass& Trans) {
294 
295  debug << "AppendInstrumentation\n";
296 
297  LLVMContext& Ctx = InstrFn->getContext();
298  auto *Head = dyn_cast<FieldAssignTransition>(*Trans.begin());
299  assert(Head);
300  auto& Protobuf = Head->Assignment();
301 
302  // The instrumentation function should be passed three parameters:
303  // the struct, the new value and a pointer to the field.
304  auto& Params = InstrFn->getArgumentList();
305  assert(Params.size() == 3);
306 
307  auto i = Params.begin();
308  llvm::Argument *Struct = &*i++;
309  llvm::Argument *NewValue = &*i++;
310  llvm::Argument *FieldPtr = &*i++;
311 
312  // We will definitely pass the structure's address to tesla_update_state().
313  // We may also pass the new value, if it's e.g. a pointer: see below.
314  SmallVector<Value*,2> KeyValues;
315  KeyValues.push_back(Struct);
316 
317  // Insert new instrumention before the current "end" block for the automaton.
318  auto *End = NextInstrBlock(&A);
319  auto *Instr = BasicBlock::Create(Ctx, Head->ShortLabel(), InstrFn, End);
320  End->replaceAllUsesWith(Instr);
321  IRBuilder<> Builder(Instr);
322 
323  // Are we assigning a constant value (in which case we should try to match
324  // it against a protobuf-supplied pattern) or a variable (in which case we
325  // should add it to the struct tesla_key)?
326  auto& ExpectedAssignment = Protobuf.value();
327 
328  switch (ExpectedAssignment.type()) {
329  case Argument::Constant: {
330  // Match the new value against the expected value or else ignore it.
331  IntegerType *ValueType = dyn_cast<IntegerType>(NewValue->getType());
332  if (!ValueType)
333  panic("NewValue not an integer type");
334 
335  auto *Match = BasicBlock::Create(Ctx, "match: " + Head->ShortLabel(),
336  InstrFn, Instr);
337  Instr->replaceAllUsesWith(Match);
338  IRBuilder<> Matcher(Match);
339 
340  auto *Const = ConstantInt::getSigned(ValueType, ExpectedAssignment.value());
341  Value *Expected;
342 
343  switch (Protobuf.operation()) {
344  case FieldAssignment::SimpleAssign:
345  Expected = Const;
346  break;
347 
348  case FieldAssignment::PlusEqual:
349  Expected = Matcher.CreateAdd(Matcher.CreateLoad(FieldPtr), Const);
350  break;
351 
352  case FieldAssignment::MinusEqual:
353  Expected = Matcher.CreateSub(Matcher.CreateLoad(FieldPtr), Const);
354  break;
355  }
356 
357  Matcher.CreateCondBr(Matcher.CreateICmpNE(NewValue, Expected), End, Instr);
358  break;
359  }
360 
361  case Argument::Variable:
362  KeyValues.push_back(NewValue);
363  break;
364 
365  case Argument::Any:
366  panic("'ANY' value should never be passed to struct field instrumentation");
367 
368  case Argument::Indirect:
369  panic("struct field instrumentation should not be passed indirect value");
370 
371  case Argument::Field:
372  panic("struct field instrumentation should not be passed struct field");
373  }
374 
375  Type* IntType = Type::getInt32Ty(Ctx);
376 
377  std::vector<Value*> Args;
378  Args.push_back(TeslaContext(A.getAssertion().context(), Ctx));
379  Args.push_back(ConstantInt::get(IntType, A.ID()));
380  Args.push_back(ConstructKey(Builder, M, KeyValues));
381  Args.push_back(Builder.CreateGlobalStringPtr(A.Name()));
382  Args.push_back(Builder.CreateGlobalStringPtr(A.String()));
383  Args.push_back(ConstructTransitions(Builder, M, Trans));
384 
385  Function *UpdateStateFn = FindStateUpdateFn(M, IntType);
386  assert(Args.size() == UpdateStateFn->arg_size());
387  Builder.CreateCall(UpdateStateFn, Args);
388  Builder.CreateBr(Exit);
389 }
390 
391 BasicBlock* FieldInstrumentation::NextInstrBlock(const Automaton *A) {
392  auto Existing = NextInstr.find(A);
393  if (Existing != NextInstr.end())
394  return Existing->second;
395 
396  auto& Ctx = M.getContext();
397  auto *Start = BasicBlock::Create(Ctx, A->Name(), InstrFn, Exit);
398  auto *End = BasicBlock::Create(Ctx, A->Name() + ":end", InstrFn, Exit);
399 
400  Exit->replaceAllUsesWith(End);
401 
402  IRBuilder<>(Start).CreateBr(End);
403  IRBuilder<>(End).CreateBr(Exit);
404 
405  NextInstr[A] = End;
406  return End;
407 }
408 
409 }